Unverified Commit 46199d5f authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Rewrite of is_functionally_identical behavior (#366)

* change default is_functionally_identical to return false so if an op forgets to override it gets a behavior that might be slower to compile but it will at least work
parent 9d0d7a7c
...@@ -337,6 +337,11 @@ bool Node::has_same_type(std::shared_ptr<const Node> node) const ...@@ -337,6 +337,11 @@ bool Node::has_same_type(std::shared_ptr<const Node> node) const
} }
bool Node::is_functionally_identical(const Node& other) const bool Node::is_functionally_identical(const Node& other) const
{
return false;
}
bool Node::test_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (this->description() == other.description()) if (this->description() == other.description())
......
...@@ -171,6 +171,7 @@ namespace ngraph ...@@ -171,6 +171,7 @@ namespace ngraph
protected: protected:
void add_output(const element::Type& element_type, const Shape& shape); void add_output(const element::Type& element_type, const Shape& shape);
void assert_argument_list_equivalency(const Nodes& b); void assert_argument_list_equivalency(const Nodes& b);
bool test_identical(const Node&) const;
std::string m_node_type; std::string m_node_type;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
......
...@@ -23,3 +23,8 @@ void ngraph::op::Abs::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -23,3 +23,8 @@ void ngraph::op::Abs::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta * std::make_shared<op::Sign>(x)); adjoints.add_delta(x, delta * std::make_shared<op::Sign>(x));
} }
bool ngraph::op::Abs::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -53,6 +53,7 @@ namespace ngraph ...@@ -53,6 +53,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Abs>(new_args.at(0)); return std::make_shared<Abs>(new_args.at(0));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -53,6 +53,10 @@ namespace ngraph ...@@ -53,6 +53,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Acos>(new_args.at(0)); return std::make_shared<Acos>(new_args.at(0));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -23,3 +23,8 @@ void ngraph::op::Add::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -23,3 +23,8 @@ void ngraph::op::Add::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta); adjoints.add_delta(x, delta);
adjoints.add_delta(y, delta); adjoints.add_delta(y, delta);
} }
bool ngraph::op::Add::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -55,6 +55,7 @@ namespace ngraph ...@@ -55,6 +55,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Add>(new_args.at(0), new_args.at(1)); return std::make_shared<Add>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -53,6 +53,10 @@ namespace ngraph ...@@ -53,6 +53,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Asin>(new_args.at(0)); return std::make_shared<Asin>(new_args.at(0));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -53,6 +53,10 @@ namespace ngraph ...@@ -53,6 +53,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Atan>(new_args.at(0)); return std::make_shared<Atan>(new_args.at(0));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -49,7 +49,7 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -49,7 +49,7 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
bool op::Broadcast::is_functionally_identical(const Node& other) const bool op::Broadcast::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const Broadcast& obj = dynamic_cast<const Broadcast&>(other); const Broadcast& obj = dynamic_cast<const Broadcast&>(other);
rc &= m_shape == obj.m_shape; rc &= m_shape == obj.m_shape;
......
...@@ -51,6 +51,10 @@ namespace ngraph ...@@ -51,6 +51,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Ceiling>(new_args.at(0)); return std::make_shared<Ceiling>(new_args.at(0));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -106,7 +106,7 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const std::shar ...@@ -106,7 +106,7 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const std::shar
bool op::Concat::is_functionally_identical(const Node& other) const bool op::Concat::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const Concat& concat = dynamic_cast<const Concat&>(other); const Concat& concat = dynamic_cast<const Concat&>(other);
rc &= m_concatenation_axis == concat.m_concatenation_axis; rc &= m_concatenation_axis == concat.m_concatenation_axis;
......
...@@ -40,7 +40,7 @@ void ngraph::op::Convert::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -40,7 +40,7 @@ void ngraph::op::Convert::generate_adjoints(autodiff::Adjoints& adjoints,
bool op::Convert::is_functionally_identical(const Node& other) const bool op::Convert::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const Convert& obj = dynamic_cast<const Convert&>(other); const Convert& obj = dynamic_cast<const Convert&>(other);
rc &= m_element_type == obj.m_element_type; rc &= m_element_type == obj.m_element_type;
......
...@@ -294,7 +294,7 @@ std::shared_ptr<Node> ...@@ -294,7 +294,7 @@ std::shared_ptr<Node>
bool op::Convolution::is_functionally_identical(const Node& other) const bool op::Convolution::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const Convolution& rhs = dynamic_cast<const Convolution&>(other); const Convolution& rhs = dynamic_cast<const Convolution&>(other);
rc &= m_window_movement_strides == rhs.m_window_movement_strides; rc &= m_window_movement_strides == rhs.m_window_movement_strides;
......
...@@ -24,3 +24,8 @@ void ngraph::op::Cos::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -24,3 +24,8 @@ void ngraph::op::Cos::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, -delta * (std::make_shared<op::Sin>(x))); adjoints.add_delta(x, -delta * (std::make_shared<op::Sin>(x)));
} }
bool ngraph::op::Cos::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -51,6 +51,7 @@ namespace ngraph ...@@ -51,6 +51,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Cos>(new_args.at(0)); return std::make_shared<Cos>(new_args.at(0));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -23,3 +23,8 @@ void ngraph::op::Cosh::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -23,3 +23,8 @@ void ngraph::op::Cosh::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta * (std::make_shared<op::Sinh>(x))); adjoints.add_delta(x, delta * (std::make_shared<op::Sinh>(x)));
} }
bool ngraph::op::Cosh::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -51,6 +51,7 @@ namespace ngraph ...@@ -51,6 +51,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Cosh>(new_args.at(0)); return std::make_shared<Cosh>(new_args.at(0));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -25,3 +25,8 @@ void ngraph::op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -25,3 +25,8 @@ void ngraph::op::Divide::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta * shared_from_this() / x); adjoints.add_delta(x, delta * shared_from_this() / x);
adjoints.add_delta(y, -delta * shared_from_this() / y); adjoints.add_delta(y, -delta * shared_from_this() / y);
} }
bool ngraph::op::Divide::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -56,6 +56,7 @@ namespace ngraph ...@@ -56,6 +56,7 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
bool is_functionally_identical(const Node&) const override;
}; };
} }
inline std::shared_ptr<ngraph::Node> operator/(const std::shared_ptr<ngraph::Node> arg0, inline std::shared_ptr<ngraph::Node> operator/(const std::shared_ptr<ngraph::Node> arg0,
......
...@@ -147,7 +147,7 @@ void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ ...@@ -147,7 +147,7 @@ void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_
bool op::Dot::is_functionally_identical(const Node& other) const bool op::Dot::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const Dot& rhs = dynamic_cast<const Dot&>(other); const Dot& rhs = dynamic_cast<const Dot&>(other);
rc &= m_reduction_axes_count == rhs.m_reduction_axes_count; rc &= m_reduction_axes_count == rhs.m_reduction_axes_count;
......
...@@ -53,6 +53,10 @@ namespace ngraph ...@@ -53,6 +53,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Equal>(new_args.at(0), new_args.at(1)); return std::make_shared<Equal>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -22,3 +22,8 @@ void ngraph::op::Exp::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -22,3 +22,8 @@ void ngraph::op::Exp::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta * shared_from_this()); adjoints.add_delta(x, delta * shared_from_this());
} }
bool ngraph::op::Exp::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -54,6 +54,7 @@ namespace ngraph ...@@ -54,6 +54,7 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
bool is_functionally_identical(const Node&) const override;
}; };
} }
} }
...@@ -51,6 +51,10 @@ namespace ngraph ...@@ -51,6 +51,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Floor>(new_args.at(0)); return std::make_shared<Floor>(new_args.at(0));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -47,8 +47,3 @@ op::FunctionCall::FunctionCall(std::shared_ptr<Function> function, ...@@ -47,8 +47,3 @@ op::FunctionCall::FunctionCall(std::shared_ptr<Function> function,
add_output(function->get_output_element_type(i), function->get_output_shape(i)); add_output(function->get_output_element_type(i), function->get_output_shape(i));
} }
} }
bool op::FunctionCall::is_functionally_identical(const Node&) const
{
return false;
}
...@@ -55,8 +55,6 @@ namespace ngraph ...@@ -55,8 +55,6 @@ namespace ngraph
return std::make_shared<FunctionCall>(m_function, new_args); return std::make_shared<FunctionCall>(m_function, new_args);
} }
/// \return The function to be called.
bool is_functionally_identical(const Node&) const override;
/// \return A singleton vector containing the function to be called. /// \return A singleton vector containing the function to be called.
std::vector<std::shared_ptr<Function>> get_functions() const override std::vector<std::shared_ptr<Function>> get_functions() const override
{ {
......
...@@ -30,8 +30,3 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t ...@@ -30,8 +30,3 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t
set_value_type_checked(arg->get_output_element_type(n), arg->get_output_shape(n)); set_value_type_checked(arg->get_output_element_type(n), arg->get_output_shape(n));
} }
bool op::GetOutputElement::is_functionally_identical(const Node& other) const
{
return false;
}
...@@ -58,8 +58,6 @@ namespace ngraph ...@@ -58,8 +58,6 @@ namespace ngraph
/// \return The index of the tuple element to get. /// \return The index of the tuple element to get.
size_t get_n() const { return m_n; } size_t get_n() const { return m_n; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
size_t m_n; size_t m_n;
}; };
......
...@@ -53,6 +53,10 @@ namespace ngraph ...@@ -53,6 +53,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Greater>(new_args.at(0), new_args.at(1)); return std::make_shared<Greater>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -53,6 +53,10 @@ namespace ngraph ...@@ -53,6 +53,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<GreaterEq>(new_args.at(0), new_args.at(1)); return std::make_shared<GreaterEq>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -53,6 +53,10 @@ namespace ngraph ...@@ -53,6 +53,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Less>(new_args.at(0), new_args.at(1)); return std::make_shared<Less>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -53,6 +53,10 @@ namespace ngraph ...@@ -53,6 +53,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<LessEq>(new_args.at(0), new_args.at(1)); return std::make_shared<LessEq>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -22,3 +22,8 @@ void ngraph::op::Log::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -22,3 +22,8 @@ void ngraph::op::Log::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta / x); adjoints.add_delta(x, delta / x);
} }
bool ngraph::op::Log::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -54,6 +54,7 @@ namespace ngraph ...@@ -54,6 +54,7 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
bool is_functionally_identical(const Node&) const override;
}; };
} }
} }
...@@ -150,7 +150,7 @@ op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape ...@@ -150,7 +150,7 @@ op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape
bool op::MaxPool::is_functionally_identical(const Node& other) const bool op::MaxPool::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const MaxPool& rhs = dynamic_cast<const MaxPool&>(other); const MaxPool& rhs = dynamic_cast<const MaxPool&>(other);
rc &= m_window_shape == rhs.m_window_shape; rc &= m_window_shape == rhs.m_window_shape;
......
...@@ -33,3 +33,8 @@ void ngraph::op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -33,3 +33,8 @@ void ngraph::op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta( adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::Greater>(y, x), y->get_element_type())); y, delta * make_shared<op::Convert>(make_shared<op::Greater>(y, x), y->get_element_type()));
} }
bool ngraph::op::Maximum::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -53,6 +53,7 @@ namespace ngraph ...@@ -53,6 +53,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Maximum>(new_args.at(0), new_args.at(1)); return std::make_shared<Maximum>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -34,3 +34,8 @@ void ngraph::op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -34,3 +34,8 @@ void ngraph::op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta( adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::Less>(y, x), y->get_element_type())); y, delta * make_shared<op::Convert>(make_shared<op::Less>(y, x), y->get_element_type()));
} }
bool ngraph::op::Minimum::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -53,6 +53,7 @@ namespace ngraph ...@@ -53,6 +53,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Minimum>(new_args.at(0), new_args.at(1)); return std::make_shared<Minimum>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -23,3 +23,8 @@ void ngraph::op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -23,3 +23,8 @@ void ngraph::op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta * y); adjoints.add_delta(x, delta * y);
adjoints.add_delta(y, x * delta); adjoints.add_delta(y, x * delta);
} }
bool ngraph::op::Multiply::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -53,6 +53,7 @@ namespace ngraph ...@@ -53,6 +53,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Multiply>(new_args.at(0), new_args.at(1)); return std::make_shared<Multiply>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -21,3 +21,8 @@ void ngraph::op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -21,3 +21,8 @@ void ngraph::op::Negative::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, -delta); adjoints.add_delta(x, -delta);
} }
bool ngraph::op::Negative::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -54,6 +54,7 @@ namespace ngraph ...@@ -54,6 +54,7 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
bool is_functionally_identical(const Node&) const override;
}; };
} }
inline std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0) inline std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0)
......
...@@ -34,3 +34,8 @@ op::Not::Not(const std::shared_ptr<Node>& arg) ...@@ -34,3 +34,8 @@ op::Not::Not(const std::shared_ptr<Node>& arg)
arg) arg)
{ {
} }
bool ngraph::op::Not::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -48,6 +48,7 @@ namespace ngraph ...@@ -48,6 +48,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Not>(new_args.at(0)); return std::make_shared<Not>(new_args.at(0));
} }
bool is_functionally_identical(const Node&) const override;
}; };
} }
} }
...@@ -53,6 +53,10 @@ namespace ngraph ...@@ -53,6 +53,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<NotEqual>(new_args.at(0), new_args.at(1)); return std::make_shared<NotEqual>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -45,7 +45,7 @@ op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t ...@@ -45,7 +45,7 @@ op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t
bool op::OneHot::is_functionally_identical(const Node& other) const bool op::OneHot::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const OneHot& rhs = dynamic_cast<const OneHot&>(other); const OneHot& rhs = dynamic_cast<const OneHot&>(other);
rc &= m_shape == rhs.m_shape; rc &= m_shape == rhs.m_shape;
......
...@@ -28,3 +28,8 @@ void ngraph::op::Power::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -28,3 +28,8 @@ void ngraph::op::Power::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta * y * shared_from_this() / x); adjoints.add_delta(x, delta * y * shared_from_this() / x);
adjoints.add_delta(y, delta * shared_from_this() * log_x); adjoints.add_delta(y, delta * shared_from_this() * log_x);
} }
bool ngraph::op::Power::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -53,6 +53,7 @@ namespace ngraph ...@@ -53,6 +53,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Power>(new_args.at(0), new_args.at(1)); return std::make_shared<Power>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -90,8 +90,3 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee, ...@@ -90,8 +90,3 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
add_output(input_reductee.get_element_type(), result_shape); add_output(input_reductee.get_element_type(), result_shape);
} }
bool op::Reduce::is_functionally_identical(const Node& other) const
{
return false;
}
...@@ -110,8 +110,6 @@ namespace ngraph ...@@ -110,8 +110,6 @@ namespace ngraph
} }
/// \return The axis positions (0-based) to be eliminated through reduction. /// \return The axis positions (0-based) to be eliminated through reduction.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; } const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
std::shared_ptr<Function> m_reduction_function; std::shared_ptr<Function> m_reduction_function;
AxisSet m_reduction_axes; AxisSet m_reduction_axes;
......
...@@ -55,6 +55,10 @@ namespace ngraph ...@@ -55,6 +55,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Remainder>(new_args.at(0), new_args.at(1)); return std::make_shared<Remainder>(new_args.at(0), new_args.at(1));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -136,7 +136,7 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -136,7 +136,7 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints,
bool op::ReplaceSlice::is_functionally_identical(const Node& other) const bool op::ReplaceSlice::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const ReplaceSlice& slice = dynamic_cast<const ReplaceSlice&>(other); const ReplaceSlice& slice = dynamic_cast<const ReplaceSlice&>(other);
rc &= m_lower_bounds == slice.m_lower_bounds; rc &= m_lower_bounds == slice.m_lower_bounds;
......
...@@ -103,7 +103,7 @@ void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -103,7 +103,7 @@ void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints,
bool op::Reshape::is_functionally_identical(const Node& other) const bool op::Reshape::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const Reshape& reshape = dynamic_cast<const Reshape&>(other); const Reshape& reshape = dynamic_cast<const Reshape&>(other);
rc &= m_input_order == reshape.m_input_order; rc &= m_input_order == reshape.m_input_order;
......
...@@ -54,7 +54,7 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -54,7 +54,7 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints,
bool op::Reverse::is_functionally_identical(const Node& other) const bool op::Reverse::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const Reverse& obj = dynamic_cast<const Reverse&>(other); const Reverse& obj = dynamic_cast<const Reverse&>(other);
rc &= m_reversed_axes == obj.m_reversed_axes; rc &= m_reversed_axes == obj.m_reversed_axes;
......
...@@ -63,3 +63,8 @@ void ngraph::op::Select::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -63,3 +63,8 @@ void ngraph::op::Select::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta * p_as_x_type); adjoints.add_delta(x, delta * p_as_x_type);
adjoints.add_delta(y, delta * not_p_as_y_type); adjoints.add_delta(y, delta * not_p_as_y_type);
} }
bool ngraph::op::Select::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -54,6 +54,7 @@ namespace ngraph ...@@ -54,6 +54,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2)); return std::make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -213,8 +213,3 @@ op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee ...@@ -213,8 +213,3 @@ op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee
// //
set_value_type_checked(input_selectee_element_type, input_selectee_shape); set_value_type_checked(input_selectee_element_type, input_selectee_shape);
} }
bool op::SelectAndScatter::is_functionally_identical(const Node& other) const
{
return false;
}
...@@ -108,8 +108,6 @@ namespace ngraph ...@@ -108,8 +108,6 @@ namespace ngraph
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
/// \return The window movement strides. /// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
std::shared_ptr<Function> m_selection_function; std::shared_ptr<Function> m_selection_function;
std::shared_ptr<Function> m_scatter_function; std::shared_ptr<Function> m_scatter_function;
......
...@@ -53,6 +53,10 @@ namespace ngraph ...@@ -53,6 +53,10 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sign>(new_args.at(0)); return std::make_shared<Sign>(new_args.at(0));
} }
bool is_functionally_identical(const Node& other) const override
{
return test_identical(other);
}
}; };
} }
} }
...@@ -23,3 +23,8 @@ void ngraph::op::Sin::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -23,3 +23,8 @@ void ngraph::op::Sin::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta * (std::make_shared<op::Cos>(x))); adjoints.add_delta(x, delta * (std::make_shared<op::Cos>(x)));
} }
bool ngraph::op::Sin::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -51,6 +51,7 @@ namespace ngraph ...@@ -51,6 +51,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sin>(new_args.at(0)); return std::make_shared<Sin>(new_args.at(0));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -23,3 +23,8 @@ void ngraph::op::Sinh::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -23,3 +23,8 @@ void ngraph::op::Sinh::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta * (std::make_shared<op::Cosh>(x))); adjoints.add_delta(x, delta * (std::make_shared<op::Cosh>(x)));
} }
bool ngraph::op::Sinh::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -51,6 +51,7 @@ namespace ngraph ...@@ -51,6 +51,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sinh>(new_args.at(0)); return std::make_shared<Sinh>(new_args.at(0));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -101,7 +101,7 @@ void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const std::share ...@@ -101,7 +101,7 @@ void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const std::share
bool op::Slice::is_functionally_identical(const Node& other) const bool op::Slice::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const Slice& slice = dynamic_cast<const Slice&>(other); const Slice& slice = dynamic_cast<const Slice&>(other);
rc &= m_lower_bounds == slice.m_lower_bounds; rc &= m_lower_bounds == slice.m_lower_bounds;
......
...@@ -23,3 +23,8 @@ void ngraph::op::Sqrt::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -23,3 +23,8 @@ void ngraph::op::Sqrt::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta / (shared_from_this() + shared_from_this())); adjoints.add_delta(x, delta / (shared_from_this() + shared_from_this()));
} }
bool ngraph::op::Sqrt::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -51,6 +51,7 @@ namespace ngraph ...@@ -51,6 +51,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sqrt>(new_args.at(0)); return std::make_shared<Sqrt>(new_args.at(0));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -24,3 +24,8 @@ void ngraph::op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -24,3 +24,8 @@ void ngraph::op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta); adjoints.add_delta(x, delta);
adjoints.add_delta(y, -delta); adjoints.add_delta(y, -delta);
} }
bool ngraph::op::Subtract::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -56,6 +56,7 @@ namespace ngraph ...@@ -56,6 +56,7 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
bool is_functionally_identical(const Node&) const override;
}; };
} }
inline std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0, inline std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0,
......
...@@ -64,7 +64,7 @@ void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ ...@@ -64,7 +64,7 @@ void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_
bool op::Sum::is_functionally_identical(const Node& other) const bool op::Sum::is_functionally_identical(const Node& other) const
{ {
bool rc = true; bool rc = true;
if (Node::is_functionally_identical(other)) if (Node::test_identical(other))
{ {
const Sum& slice = dynamic_cast<const Sum&>(other); const Sum& slice = dynamic_cast<const Sum&>(other);
rc &= m_reduction_axes == slice.m_reduction_axes; rc &= m_reduction_axes == slice.m_reduction_axes;
......
...@@ -26,3 +26,8 @@ void ngraph::op::Tan::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -26,3 +26,8 @@ void ngraph::op::Tan::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta / (c * c)); adjoints.add_delta(x, delta / (c * c));
} }
bool ngraph::op::Tan::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -51,6 +51,7 @@ namespace ngraph ...@@ -51,6 +51,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Tan>(new_args.at(0)); return std::make_shared<Tan>(new_args.at(0));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -26,3 +26,8 @@ void ngraph::op::Tanh::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -26,3 +26,8 @@ void ngraph::op::Tanh::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta / (c * c)); adjoints.add_delta(x, delta / (c * c));
} }
bool ngraph::op::Tanh::is_functionally_identical(const Node& other) const
{
return test_identical(other);
}
...@@ -51,6 +51,7 @@ namespace ngraph ...@@ -51,6 +51,7 @@ namespace ngraph
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Tanh>(new_args.at(0)); return std::make_shared<Tanh>(new_args.at(0));
} }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -53,8 +53,3 @@ const Nodes& op::XLAGetTupleElement::get_tuple_elements() const ...@@ -53,8 +53,3 @@ const Nodes& op::XLAGetTupleElement::get_tuple_elements() const
{ {
return get_tuple_value()->get_tuple_elements(); return get_tuple_value()->get_tuple_elements();
} }
bool op::XLAGetTupleElement::is_functionally_identical(const Node& other) const
{
return false;
}
...@@ -64,8 +64,6 @@ namespace ngraph ...@@ -64,8 +64,6 @@ namespace ngraph
/// \return The index of the tuple element to get. /// \return The index of the tuple element to get.
size_t get_n() const { return m_n; } size_t get_n() const { return m_n; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
std::shared_ptr<XLANode> m_arg; std::shared_ptr<XLANode> m_arg;
size_t m_n; size_t m_n;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment