Commit b56c44fe authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Update add_delta_to_slice to tolerate dynamic shapes (#1962)

parent b00530a5
......@@ -177,8 +177,9 @@ void autodiff::Adjoints::add_delta_to_slice(const std::shared_ptr<Node>& x,
const Coordinate& upper_bounds,
const Strides& strides)
{
if (x->get_output_element_type(0) != delta->get_output_element_type(0) ||
x->get_output_shape(0).size() != delta->get_output_shape(0).size())
if (!(x->get_output_element_type(0).compatible(delta->get_output_element_type(0))) ||
!(x->get_output_partial_shape(0).rank().compatible(
delta->get_output_partial_shape(0).rank())))
{
throw ngraph_error(
"Autodiff internal error: Mismatch on backprop and op in add_delta_to_slice.");
......
......@@ -192,6 +192,11 @@ std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
return out;
}
bool element::Type::compatible(element::Type t) const
{
return (is_dynamic() || t.is_dynamic() || *this == t);
}
bool element::Type::merge(element::Type& dst, const element::Type& t1, const element::Type& t2)
{
if (t1.is_dynamic())
......
......@@ -77,6 +77,11 @@ namespace ngraph
/// Returns true if the type is floating point, else false.
bool get_is_real() const { return m_is_real; }
/// \brief Checks whether this element type is merge-compatible with `t`.
/// \param t The element type to compare this element type to.
/// \return `true` if this element type is compatible with `t`, else `false`.
bool compatible(element::Type t) const;
/// \brief Merges two element types t1 and t2, writing the result into dst and
/// returning true if successful, else returning false.
///
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment