Commit 7cc2a41f authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Check for element-wise equality at construction and simplify zero/one checks later (#3782)

parent 7617d385
......@@ -335,17 +335,8 @@ bool ngraph::is_equal_to_const_value(std::string const_value, const Output<Node>
{
if (auto rc = as_type_ptr<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
{
auto cshape = rc->get_shape();
size_t n = shape_size(cshape);
// way to construct a constant of a given type, shape, value
std::vector<std::string> vector_zero{n, const_value};
auto constant_val_op =
std::make_shared<ngraph::op::Constant>(rc->get_element_type(), cshape, vector_zero);
// way to compare elements to const_value
size_t n_bytes = n * rc->get_element_type().size();
NGRAPH_DEBUG << "Comparing " << n_bytes << " bytes";
return !memcmp(constant_val_op->get_data_ptr(), rc->get_data_ptr(), n_bytes);
return (rc->get_all_data_elements_bitwise_identical() &&
rc->convert_value_to_string(0) == const_value);
}
else
{
......
......@@ -70,6 +70,7 @@ namespace ngraph
write_values(values);
}
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
/// \brief Constructs a tensor constant
......@@ -128,6 +129,7 @@ namespace ngraph
}
}
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
/// \brief Constructs a tensor constant with the same initialization value copied across
......@@ -146,6 +148,7 @@ namespace ngraph
host_alignment()));
std::memcpy(m_data->get_ptr(), data, size);
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
virtual ~Constant() override;
......@@ -246,6 +249,10 @@ namespace ngraph
bool is_constant() const override { return true; }
bool are_all_data_elements_bitwise_identical() const;
bool get_all_data_elements_bitwise_identical() const
{
return m_all_elements_bitwise_identical;
}
std::string convert_value_to_string(size_t index) const;
protected:
......@@ -343,6 +350,7 @@ namespace ngraph
element::Type m_element_type;
Shape m_shape{};
std::unique_ptr<runtime::AlignedBuffer> m_data;
bool m_all_elements_bitwise_identical;
Constant(const Constant&) = delete;
Constant operator=(const Constant&) = delete;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment