Commit 501b2e22 authored by Katarzyna Mitrus's avatar Katarzyna Mitrus Committed by Scott Cyphers

Add v1 version of Subtract with Numpy broadcasting as default (#3957)

* V1 version of Subtract with default Numpy autobcast

* Update op_v1_tbl.hpp with v1 version of Subtract

* Use v1 of Subtract in ONNX importer

* Add v1 namespace

* Update namspece

* Missing punctuation

* Add Subtract to opset0 downgrade

* Add Subtract to opset1 upgrade

* Add Subtract header to cpu emmiter

* Update serializer

* Add Subtract to opset_pass tests

* Use downgrade method

* Add get_version method

* Style apply

* Add v1 Substract to check opset1

* Add NGRAPH_API before class name

* Removed get_version method

* Separate cases for Subtract and Subtract_v1 in serializer

* Update op_version_tbl with v1 Subtract

* NUMPY autobcast for no args constructor

* Add Subtract_v1 to serializer
parent 95d072aa
......@@ -48,10 +48,8 @@ namespace ngraph
{
inline NodeVector sub(const Node& node)
{
return {std::make_shared<ngraph::op::Subtract>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
return {std::make_shared<ngraph::op::v1::Subtract>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))};
}
} // namespace set_1
......
......@@ -228,6 +228,7 @@ NGRAPH_OP(Squeeze, ngraph::op::v0, 0)
NGRAPH_OP(StopGradient, ngraph::op::v0, 0)
NGRAPH_OP(StridedSlice, ngraph::op::v1, 1)
NGRAPH_OP(Subtract, ngraph::op::v0, 0)
NGRAPH_OP(Subtract, ngraph::op::v1, 1)
NGRAPH_OP(Sum, ngraph::op::v0, 0)
NGRAPH_OP(Tan, ngraph::op::v0, 0)
NGRAPH_OP(Tanh, ngraph::op::v0, 0)
......
......@@ -20,23 +20,25 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Subtract::type_info;
// ------------------------------- v0 ------------------------------------------
op::Subtract::Subtract(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::Subtract::type_info;
op::v0::Subtract::Subtract(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Subtract::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Subtract::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Subtract>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v0::Subtract>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v0::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
......@@ -54,5 +56,39 @@ void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
shared_ptr<ngraph::Node> ngraph::operator-(const Output<Node> arg0, const Output<Node> arg1)
{
return make_shared<ngraph::op::Subtract>(arg0, arg1);
return make_shared<op::v0::Subtract>(arg0, arg1);
}
// ------------------------------- v1 ------------------------------------------
constexpr NodeTypeInfo op::v1::Subtract::type_info;
op::v1::Subtract::Subtract(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::Subtract::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::Subtract>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v1::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(x, delta);
adjoints.add_delta(y, -delta);
}
......@@ -50,9 +50,41 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
} // namespace v0
namespace v1
{
/// \brief Elementwise subtraction operation.
class NGRAPH_API Subtract : public util::BinaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Subtract", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Subtract()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
{
}
/// \brief Constructs a subtraction operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Subtract(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v1
using v0::Subtract;
}
} // namespace op
std::shared_ptr<ngraph::Node> operator-(const Output<ngraph::Node> arg0,
const Output<ngraph::Node> arg1);
}
} // namespace ngraph
......@@ -146,7 +146,7 @@ NGRAPH_OP(Split, ngraph::op::v0)
NGRAPH_OP(SquaredDifference, ngraph::op::v0)
NGRAPH_OP(Squeeze, ngraph::op::v0)
NGRAPH_OP(StridedSlice, ngraph::op::v1)
NGRAPH_OP(Subtract, ngraph::op::v0)
NGRAPH_OP(Subtract, ngraph::op::v1)
NGRAPH_OP(Tan, ngraph::op::v0)
NGRAPH_OP(Tanh, ngraph::op::v0)
NGRAPH_OP(TensorIterator, ngraph::op::v0)
......@@ -155,4 +155,4 @@ NGRAPH_OP(TopK, ngraph::op::v1)
NGRAPH_OP(Transpose, ngraph::op::v0)
NGRAPH_OP(Unsqueeze, ngraph::op::v0)
NGRAPH_OP(VariadicSplit, ngraph::op::v1)
NGRAPH_OP(Xor, ngraph::op::v0)
\ No newline at end of file
NGRAPH_OP(Xor, ngraph::op::v0)
......@@ -561,6 +561,12 @@ namespace
return true;
}
bool op_cast(shared_ptr<op::v1::Subtract> node)
{
op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
return true;
}
bool op_cast(shared_ptr<op::v1::ReduceSum> node)
{
auto replacement_node =
......
......@@ -484,6 +484,12 @@ namespace
return true;
}
bool op_cast(shared_ptr<op::Subtract> node)
{
op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
return true;
}
bool op_cast(shared_ptr<op::Sum> node)
{
bool keep_dims = false;
......
......@@ -2676,6 +2676,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::Subtract_v1:
{
node = make_shared<op::v1::Subtract>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
break;
}
case OP_TYPEID::ReduceSum_v1:
{
auto keep_dims = node_js.at("keep_dims").get<bool>();
......@@ -4295,6 +4303,15 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::Subtract_v1:
{
auto tmp = static_cast<const op::v1::Subtract*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::Sum: { break;
}
case OP_TYPEID::ReduceSum_v1:
......
......@@ -131,7 +131,7 @@ TEST(opset, check_opset1)
CHECK_OPSET(op::v0::SquaredDifference, opset1::SquaredDifference)
CHECK_OPSET(op::v0::Squeeze, opset1::Squeeze)
CHECK_OPSET(op::v1::StridedSlice, opset1::StridedSlice)
CHECK_OPSET(op::v0::Subtract, opset1::Subtract)
CHECK_OPSET(op::v1::Subtract, opset1::Subtract)
CHECK_OPSET(op::v0::Tan, opset1::Tan)
CHECK_OPSET(op::v0::Tanh, opset1::Tanh)
CHECK_OPSET(op::v0::TensorIterator, opset1::TensorIterator)
......
......@@ -262,3 +262,13 @@ TEST(opset_transform, opset1_power_upgrade_pass)
{
test_opset1_arithmetic_upgrade_pass<op::v0::Power, op::v1::Power>();
}
TEST(opset_transform, opset0_subtract_downgrade_pass)
{
test_opset0_arithmetic_downgrade_pass<op::v0::Subtract, op::v1::Subtract>();
}
TEST(opset_transform, opset1_subtract_upgrade_pass)
{
test_opset1_arithmetic_upgrade_pass<op::v0::Subtract, op::v1::Subtract>();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment