Unverified Commit 7b1dc3e3 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

fix some is_functionally_identical methods (#365)

parent 7df687c1
......@@ -45,3 +45,19 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes));
}
bool op::Broadcast::is_functionally_identical(const Node& other) const
{
bool rc = true;
if (Node::is_functionally_identical(other))
{
const Broadcast& obj = dynamic_cast<const Broadcast&>(other);
rc &= m_shape == obj.m_shape;
rc &= m_broadcast_axes == obj.m_broadcast_axes;
}
else
{
rc = false;
}
return rc;
}
......@@ -73,6 +73,8 @@ namespace ngraph
/// \return An set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
const Shape& get_broadcast_shape() const { return m_shape; }
bool is_functionally_identical(const Node&) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -36,3 +36,18 @@ void ngraph::op::Convert::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, std::make_shared<op::Convert>(delta, x->get_element_type()));
}
bool op::Convert::is_functionally_identical(const Node& other) const
{
bool rc = true;
if (Node::is_functionally_identical(other))
{
const Convert& obj = dynamic_cast<const Convert&>(other);
rc &= m_element_type == obj.m_element_type;
}
else
{
rc = false;
}
return rc;
}
......@@ -60,6 +60,8 @@ namespace ngraph
}
const element::Type& get_convert_element_type() const { return m_element_type; }
bool is_functionally_identical(const Node&) const override;
protected:
const ngraph::element::Type& m_element_type;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -30,3 +30,8 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t
set_value_type_checked(arg->get_output_element_type(n), arg->get_output_shape(n));
}
bool op::GetOutputElement::is_functionally_identical(const Node& other) const
{
return false;
}
......@@ -58,6 +58,8 @@ namespace ngraph
/// \return The index of the tuple element to get.
size_t get_n() const { return m_n; }
bool is_functionally_identical(const Node&) const override;
protected:
size_t m_n;
};
......
......@@ -127,3 +127,8 @@ op::ReduceWindow::ReduceWindow(const std::shared_ptr<Node>& arg_reductee,
set_value_type_checked(input_reductee.get_element_type(), result_shape);
}
bool op::ReduceWindow::is_functionally_identical(const Node& other) const
{
return false;
}
......@@ -85,6 +85,8 @@ namespace ngraph
const Shape& get_window_shape() const { return m_window_shape; }
/// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
bool is_functionally_identical(const Node&) const override;
protected:
std::shared_ptr<Function> m_reduction_function;
Shape m_window_shape;
......
......@@ -50,3 +50,18 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, make_shared<op::Reverse>(delta, m_reversed_axes));
}
bool op::Reverse::is_functionally_identical(const Node& other) const
{
bool rc = true;
if (Node::is_functionally_identical(other))
{
const Reverse& obj = dynamic_cast<const Reverse&>(other);
rc &= m_reversed_axes == obj.m_reversed_axes;
}
else
{
rc = false;
}
return rc;
}
......@@ -60,6 +60,8 @@ namespace ngraph
/// \return The set of axes to reverse.
const AxisSet& get_reversed_axes() const { return m_reversed_axes; }
bool is_functionally_identical(const Node&) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -213,3 +213,8 @@ op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee
//
set_value_type_checked(input_selectee_element_type, input_selectee_shape);
}
bool op::SelectAndScatter::is_functionally_identical(const Node& other) const
{
return false;
}
......@@ -88,6 +88,8 @@ namespace ngraph
const Shape& get_window_shape() const { return m_window_shape; }
/// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
bool is_functionally_identical(const Node&) const override;
protected:
std::shared_ptr<Function> m_selection_function;
std::shared_ptr<Function> m_scatter_function;
......
......@@ -53,3 +53,8 @@ const Nodes& op::XLAGetTupleElement::get_tuple_elements() const
{
return get_tuple_value()->get_tuple_elements();
}
bool op::XLAGetTupleElement::is_functionally_identical(const Node& other) const
{
return false;
}
......@@ -64,6 +64,8 @@ namespace ngraph
/// \return The index of the tuple element to get.
size_t get_n() const { return m_n; }
bool is_functionally_identical(const Node&) const override;
protected:
std::shared_ptr<XLANode> m_arg;
size_t m_n;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment