Unverified Commit 7b1dc3e3 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

fix some is_functionally_identical methods (#365)

parent 7df687c1
...@@ -45,3 +45,19 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -45,3 +45,19 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes)); adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes));
} }
bool op::Broadcast::is_functionally_identical(const Node& other) const
{
bool rc = true;
if (Node::is_functionally_identical(other))
{
const Broadcast& obj = dynamic_cast<const Broadcast&>(other);
rc &= m_shape == obj.m_shape;
rc &= m_broadcast_axes == obj.m_broadcast_axes;
}
else
{
rc = false;
}
return rc;
}
...@@ -73,6 +73,8 @@ namespace ngraph ...@@ -73,6 +73,8 @@ namespace ngraph
/// \return An set containing the indices of the broadcast axes (0-based). /// \return An set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; } const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
const Shape& get_broadcast_shape() const { return m_shape; } const Shape& get_broadcast_shape() const { return m_shape; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
......
...@@ -36,3 +36,18 @@ void ngraph::op::Convert::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -36,3 +36,18 @@ void ngraph::op::Convert::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, std::make_shared<op::Convert>(delta, x->get_element_type())); adjoints.add_delta(x, std::make_shared<op::Convert>(delta, x->get_element_type()));
} }
bool op::Convert::is_functionally_identical(const Node& other) const
{
bool rc = true;
if (Node::is_functionally_identical(other))
{
const Convert& obj = dynamic_cast<const Convert&>(other);
rc &= m_element_type == obj.m_element_type;
}
else
{
rc = false;
}
return rc;
}
...@@ -60,6 +60,8 @@ namespace ngraph ...@@ -60,6 +60,8 @@ namespace ngraph
} }
const element::Type& get_convert_element_type() const { return m_element_type; } const element::Type& get_convert_element_type() const { return m_element_type; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -30,3 +30,8 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t ...@@ -30,3 +30,8 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t
set_value_type_checked(arg->get_output_element_type(n), arg->get_output_shape(n)); set_value_type_checked(arg->get_output_element_type(n), arg->get_output_shape(n));
} }
bool op::GetOutputElement::is_functionally_identical(const Node& other) const
{
return false;
}
...@@ -58,6 +58,8 @@ namespace ngraph ...@@ -58,6 +58,8 @@ namespace ngraph
/// \return The index of the tuple element to get. /// \return The index of the tuple element to get.
size_t get_n() const { return m_n; } size_t get_n() const { return m_n; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
size_t m_n; size_t m_n;
}; };
......
...@@ -127,3 +127,8 @@ op::ReduceWindow::ReduceWindow(const std::shared_ptr<Node>& arg_reductee, ...@@ -127,3 +127,8 @@ op::ReduceWindow::ReduceWindow(const std::shared_ptr<Node>& arg_reductee,
set_value_type_checked(input_reductee.get_element_type(), result_shape); set_value_type_checked(input_reductee.get_element_type(), result_shape);
} }
bool op::ReduceWindow::is_functionally_identical(const Node& other) const
{
return false;
}
...@@ -85,6 +85,8 @@ namespace ngraph ...@@ -85,6 +85,8 @@ namespace ngraph
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
/// \return The window movement strides. /// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
std::shared_ptr<Function> m_reduction_function; std::shared_ptr<Function> m_reduction_function;
Shape m_window_shape; Shape m_window_shape;
......
...@@ -50,3 +50,18 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -50,3 +50,18 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, make_shared<op::Reverse>(delta, m_reversed_axes)); adjoints.add_delta(x, make_shared<op::Reverse>(delta, m_reversed_axes));
} }
bool op::Reverse::is_functionally_identical(const Node& other) const
{
bool rc = true;
if (Node::is_functionally_identical(other))
{
const Reverse& obj = dynamic_cast<const Reverse&>(other);
rc &= m_reversed_axes == obj.m_reversed_axes;
}
else
{
rc = false;
}
return rc;
}
...@@ -60,6 +60,8 @@ namespace ngraph ...@@ -60,6 +60,8 @@ namespace ngraph
/// \return The set of axes to reverse. /// \return The set of axes to reverse.
const AxisSet& get_reversed_axes() const { return m_reversed_axes; } const AxisSet& get_reversed_axes() const { return m_reversed_axes; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
......
...@@ -213,3 +213,8 @@ op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee ...@@ -213,3 +213,8 @@ op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee
// //
set_value_type_checked(input_selectee_element_type, input_selectee_shape); set_value_type_checked(input_selectee_element_type, input_selectee_shape);
} }
bool op::SelectAndScatter::is_functionally_identical(const Node& other) const
{
return false;
}
...@@ -88,6 +88,8 @@ namespace ngraph ...@@ -88,6 +88,8 @@ namespace ngraph
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
/// \return The window movement strides. /// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
std::shared_ptr<Function> m_selection_function; std::shared_ptr<Function> m_selection_function;
std::shared_ptr<Function> m_scatter_function; std::shared_ptr<Function> m_scatter_function;
......
...@@ -53,3 +53,8 @@ const Nodes& op::XLAGetTupleElement::get_tuple_elements() const ...@@ -53,3 +53,8 @@ const Nodes& op::XLAGetTupleElement::get_tuple_elements() const
{ {
return get_tuple_value()->get_tuple_elements(); return get_tuple_value()->get_tuple_elements();
} }
bool op::XLAGetTupleElement::is_functionally_identical(const Node& other) const
{
return false;
}
...@@ -64,6 +64,8 @@ namespace ngraph ...@@ -64,6 +64,8 @@ namespace ngraph
/// \return The index of the tuple element to get. /// \return The index of the tuple element to get.
size_t get_n() const { return m_n; } size_t get_n() const { return m_n; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
std::shared_ptr<XLANode> m_arg; std::shared_ptr<XLANode> m_arg;
size_t m_n; size_t m_n;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment