Commit 938c2a6a authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Check for element-wise equality at construction and simplify zero/one checks later (#3785)

parent 7f5ad243
...@@ -327,17 +327,8 @@ bool ngraph::is_equal_to_const_value(std::string const_value, const Output<Node> ...@@ -327,17 +327,8 @@ bool ngraph::is_equal_to_const_value(std::string const_value, const Output<Node>
{ {
if (auto rc = dynamic_pointer_cast<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr())) if (auto rc = dynamic_pointer_cast<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
{ {
auto cshape = rc->get_shape(); return (rc->get_all_data_elements_bitwise_identical() &&
size_t n = shape_size(cshape); rc->convert_value_to_string(0) == const_value);
// way to construct a constant of a given type, shape, value
std::vector<std::string> vector_zero{n, const_value};
auto constant_val_op =
std::make_shared<ngraph::op::Constant>(rc->get_element_type(), cshape, vector_zero);
// way to compare elements to const_value
size_t n_bytes = n * rc->get_element_type().size();
NGRAPH_DEBUG << "Comparing " << n_bytes << " bytes";
return !memcmp(constant_val_op->get_data_ptr(), rc->get_data_ptr(), n_bytes);
} }
else else
{ {
......
...@@ -70,6 +70,7 @@ namespace ngraph ...@@ -70,6 +70,7 @@ namespace ngraph
write_values(values); write_values(values);
} }
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
} }
/// \brief Constructs a tensor constant /// \brief Constructs a tensor constant
...@@ -128,6 +129,7 @@ namespace ngraph ...@@ -128,6 +129,7 @@ namespace ngraph
} }
} }
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
} }
/// \brief Constructs a tensor constant with the same initialization value copied across /// \brief Constructs a tensor constant with the same initialization value copied across
...@@ -146,6 +148,7 @@ namespace ngraph ...@@ -146,6 +148,7 @@ namespace ngraph
host_alignment())); host_alignment()));
std::memcpy(m_data->get_ptr(), data, size); std::memcpy(m_data->get_ptr(), data, size);
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
} }
virtual ~Constant() override; virtual ~Constant() override;
...@@ -246,6 +249,10 @@ namespace ngraph ...@@ -246,6 +249,10 @@ namespace ngraph
bool is_constant() const override { return true; } bool is_constant() const override { return true; }
bool are_all_data_elements_bitwise_identical() const; bool are_all_data_elements_bitwise_identical() const;
bool get_all_data_elements_bitwise_identical() const
{
return m_all_elements_bitwise_identical;
}
std::string convert_value_to_string(size_t index) const; std::string convert_value_to_string(size_t index) const;
protected: protected:
...@@ -343,6 +350,7 @@ namespace ngraph ...@@ -343,6 +350,7 @@ namespace ngraph
element::Type m_element_type; element::Type m_element_type;
Shape m_shape{}; Shape m_shape{};
std::unique_ptr<runtime::AlignedBuffer> m_data; std::unique_ptr<runtime::AlignedBuffer> m_data;
bool m_all_elements_bitwise_identical;
Constant(const Constant&) = delete; Constant(const Constant&) = delete;
Constant operator=(const Constant&) = delete; Constant operator=(const Constant&) = delete;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment