Commit 501b2e22 authored by Katarzyna Mitrus's avatar Katarzyna Mitrus Committed by Scott Cyphers

Add v1 version of Subtract with Numpy broadcasting as default (#3957)

* V1 version of Subtract with default Numpy autobcast

* Update op_v1_tbl.hpp with v1 version of Subtract

* Use v1 of Subtract in ONNX importer

* Add v1 namespace

* Update namspece

* Missing punctuation

* Add Subtract to opset0 downgrade

* Add Subtract to opset1 upgrade

* Add Subtract header to cpu emmiter

* Update serializer

* Add Subtract to opset_pass tests

* Use downgrade method

* Add get_version method

* Style apply

* Add v1 Substract to check opset1

* Add NGRAPH_API before class name

* Removed get_version method

* Separate cases for Subtract and Subtract_v1 in serializer

* Update op_version_tbl with v1 Subtract

* NUMPY autobcast for no args constructor

* Add Subtract_v1 to serializer
parent 95d072aa
...@@ -48,10 +48,8 @@ namespace ngraph ...@@ -48,10 +48,8 @@ namespace ngraph
{ {
inline NodeVector sub(const Node& node) inline NodeVector sub(const Node& node)
{ {
return {std::make_shared<ngraph::op::Subtract>( return {std::make_shared<ngraph::op::v1::Subtract>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -228,6 +228,7 @@ NGRAPH_OP(Squeeze, ngraph::op::v0, 0) ...@@ -228,6 +228,7 @@ NGRAPH_OP(Squeeze, ngraph::op::v0, 0)
NGRAPH_OP(StopGradient, ngraph::op::v0, 0) NGRAPH_OP(StopGradient, ngraph::op::v0, 0)
NGRAPH_OP(StridedSlice, ngraph::op::v1, 1) NGRAPH_OP(StridedSlice, ngraph::op::v1, 1)
NGRAPH_OP(Subtract, ngraph::op::v0, 0) NGRAPH_OP(Subtract, ngraph::op::v0, 0)
NGRAPH_OP(Subtract, ngraph::op::v1, 1)
NGRAPH_OP(Sum, ngraph::op::v0, 0) NGRAPH_OP(Sum, ngraph::op::v0, 0)
NGRAPH_OP(Tan, ngraph::op::v0, 0) NGRAPH_OP(Tan, ngraph::op::v0, 0)
NGRAPH_OP(Tanh, ngraph::op::v0, 0) NGRAPH_OP(Tanh, ngraph::op::v0, 0)
......
...@@ -20,9 +20,11 @@ ...@@ -20,9 +20,11 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Subtract::type_info; // ------------------------------- v0 ------------------------------------------
op::Subtract::Subtract(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::Subtract::type_info;
op::v0::Subtract::Subtract(const Output<Node>& arg0,
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast) : BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
...@@ -30,13 +32,13 @@ op::Subtract::Subtract(const Output<Node>& arg0, ...@@ -30,13 +32,13 @@ op::Subtract::Subtract(const Output<Node>& arg0,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Subtract::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Subtract::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Subtract>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v0::Subtract>(new_args.at(0), new_args.at(1), this->get_autob());
} }
void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v0::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (get_autob().m_type != op::AutoBroadcastType::NONE) if (get_autob().m_type != op::AutoBroadcastType::NONE)
{ {
...@@ -54,5 +56,39 @@ void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec ...@@ -54,5 +56,39 @@ void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
shared_ptr<ngraph::Node> ngraph::operator-(const Output<Node> arg0, const Output<Node> arg1) shared_ptr<ngraph::Node> ngraph::operator-(const Output<Node> arg0, const Output<Node> arg1)
{ {
return make_shared<ngraph::op::Subtract>(arg0, arg1); return make_shared<op::v0::Subtract>(arg0, arg1);
}
// ------------------------------- v1 ------------------------------------------
constexpr NodeTypeInfo op::v1::Subtract::type_info;
op::v1::Subtract::Subtract(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::Subtract::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::Subtract>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v1::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(x, delta);
adjoints.add_delta(y, -delta);
} }
...@@ -50,9 +50,41 @@ namespace ngraph ...@@ -50,9 +50,41 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
}; };
} // namespace v0
namespace v1
{
/// \brief Elementwise subtraction operation.
class NGRAPH_API Subtract : public util::BinaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Subtract", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Subtract()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
} }
/// \brief Constructs a subtraction operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Subtract(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v1
using v0::Subtract; using v0::Subtract;
} } // namespace op
std::shared_ptr<ngraph::Node> operator-(const Output<ngraph::Node> arg0, std::shared_ptr<ngraph::Node> operator-(const Output<ngraph::Node> arg0,
const Output<ngraph::Node> arg1); const Output<ngraph::Node> arg1);
} } // namespace ngraph
...@@ -146,7 +146,7 @@ NGRAPH_OP(Split, ngraph::op::v0) ...@@ -146,7 +146,7 @@ NGRAPH_OP(Split, ngraph::op::v0)
NGRAPH_OP(SquaredDifference, ngraph::op::v0) NGRAPH_OP(SquaredDifference, ngraph::op::v0)
NGRAPH_OP(Squeeze, ngraph::op::v0) NGRAPH_OP(Squeeze, ngraph::op::v0)
NGRAPH_OP(StridedSlice, ngraph::op::v1) NGRAPH_OP(StridedSlice, ngraph::op::v1)
NGRAPH_OP(Subtract, ngraph::op::v0) NGRAPH_OP(Subtract, ngraph::op::v1)
NGRAPH_OP(Tan, ngraph::op::v0) NGRAPH_OP(Tan, ngraph::op::v0)
NGRAPH_OP(Tanh, ngraph::op::v0) NGRAPH_OP(Tanh, ngraph::op::v0)
NGRAPH_OP(TensorIterator, ngraph::op::v0) NGRAPH_OP(TensorIterator, ngraph::op::v0)
......
...@@ -561,6 +561,12 @@ namespace ...@@ -561,6 +561,12 @@ namespace
return true; return true;
} }
bool op_cast(shared_ptr<op::v1::Subtract> node)
{
op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
return true;
}
bool op_cast(shared_ptr<op::v1::ReduceSum> node) bool op_cast(shared_ptr<op::v1::ReduceSum> node)
{ {
auto replacement_node = auto replacement_node =
......
...@@ -484,6 +484,12 @@ namespace ...@@ -484,6 +484,12 @@ namespace
return true; return true;
} }
bool op_cast(shared_ptr<op::Subtract> node)
{
op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
return true;
}
bool op_cast(shared_ptr<op::Sum> node) bool op_cast(shared_ptr<op::Sum> node)
{ {
bool keep_dims = false; bool keep_dims = false;
......
...@@ -2676,6 +2676,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2676,6 +2676,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast")); args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break; break;
} }
case OP_TYPEID::Subtract_v1:
{
node = make_shared<op::v1::Subtract>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
break;
}
case OP_TYPEID::ReduceSum_v1: case OP_TYPEID::ReduceSum_v1:
{ {
auto keep_dims = node_js.at("keep_dims").get<bool>(); auto keep_dims = node_js.at("keep_dims").get<bool>();
...@@ -4295,6 +4303,15 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4295,6 +4303,15 @@ json JSONSerializer::serialize_node(const Node& n)
} }
break; break;
} }
case OP_TYPEID::Subtract_v1:
{
auto tmp = static_cast<const op::v1::Subtract*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::Sum: { break; case OP_TYPEID::Sum: { break;
} }
case OP_TYPEID::ReduceSum_v1: case OP_TYPEID::ReduceSum_v1:
......
...@@ -131,7 +131,7 @@ TEST(opset, check_opset1) ...@@ -131,7 +131,7 @@ TEST(opset, check_opset1)
CHECK_OPSET(op::v0::SquaredDifference, opset1::SquaredDifference) CHECK_OPSET(op::v0::SquaredDifference, opset1::SquaredDifference)
CHECK_OPSET(op::v0::Squeeze, opset1::Squeeze) CHECK_OPSET(op::v0::Squeeze, opset1::Squeeze)
CHECK_OPSET(op::v1::StridedSlice, opset1::StridedSlice) CHECK_OPSET(op::v1::StridedSlice, opset1::StridedSlice)
CHECK_OPSET(op::v0::Subtract, opset1::Subtract) CHECK_OPSET(op::v1::Subtract, opset1::Subtract)
CHECK_OPSET(op::v0::Tan, opset1::Tan) CHECK_OPSET(op::v0::Tan, opset1::Tan)
CHECK_OPSET(op::v0::Tanh, opset1::Tanh) CHECK_OPSET(op::v0::Tanh, opset1::Tanh)
CHECK_OPSET(op::v0::TensorIterator, opset1::TensorIterator) CHECK_OPSET(op::v0::TensorIterator, opset1::TensorIterator)
......
...@@ -262,3 +262,13 @@ TEST(opset_transform, opset1_power_upgrade_pass) ...@@ -262,3 +262,13 @@ TEST(opset_transform, opset1_power_upgrade_pass)
{ {
test_opset1_arithmetic_upgrade_pass<op::v0::Power, op::v1::Power>(); test_opset1_arithmetic_upgrade_pass<op::v0::Power, op::v1::Power>();
} }
TEST(opset_transform, opset0_subtract_downgrade_pass)
{
test_opset0_arithmetic_downgrade_pass<op::v0::Subtract, op::v1::Subtract>();
}
TEST(opset_transform, opset1_subtract_upgrade_pass)
{
test_opset1_arithmetic_upgrade_pass<op::v0::Subtract, op::v1::Subtract>();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment