Commit b56c44fe authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Update add_delta_to_slice to tolerate dynamic shapes (#1962)

parent b00530a5
...@@ -177,8 +177,9 @@ void autodiff::Adjoints::add_delta_to_slice(const std::shared_ptr<Node>& x, ...@@ -177,8 +177,9 @@ void autodiff::Adjoints::add_delta_to_slice(const std::shared_ptr<Node>& x,
const Coordinate& upper_bounds, const Coordinate& upper_bounds,
const Strides& strides) const Strides& strides)
{ {
if (x->get_output_element_type(0) != delta->get_output_element_type(0) || if (!(x->get_output_element_type(0).compatible(delta->get_output_element_type(0))) ||
x->get_output_shape(0).size() != delta->get_output_shape(0).size()) !(x->get_output_partial_shape(0).rank().compatible(
delta->get_output_partial_shape(0).rank())))
{ {
throw ngraph_error( throw ngraph_error(
"Autodiff internal error: Mismatch on backprop and op in add_delta_to_slice."); "Autodiff internal error: Mismatch on backprop and op in add_delta_to_slice.");
......
...@@ -192,6 +192,11 @@ std::ostream& element::operator<<(std::ostream& out, const element::Type& obj) ...@@ -192,6 +192,11 @@ std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
return out; return out;
} }
bool element::Type::compatible(element::Type t) const
{
return (is_dynamic() || t.is_dynamic() || *this == t);
}
bool element::Type::merge(element::Type& dst, const element::Type& t1, const element::Type& t2) bool element::Type::merge(element::Type& dst, const element::Type& t1, const element::Type& t2)
{ {
if (t1.is_dynamic()) if (t1.is_dynamic())
......
...@@ -77,6 +77,11 @@ namespace ngraph ...@@ -77,6 +77,11 @@ namespace ngraph
/// Returns true if the type is floating point, else false. /// Returns true if the type is floating point, else false.
bool get_is_real() const { return m_is_real; } bool get_is_real() const { return m_is_real; }
/// \brief Checks whether this element type is merge-compatible with `t`.
/// \param t The element type to compare this element type to.
/// \return `true` if this element type is compatible with `t`, else `false`.
bool compatible(element::Type t) const;
/// \brief Merges two element types t1 and t2, writing the result into dst and /// \brief Merges two element types t1 and t2, writing the result into dst and
/// returning true if successful, else returning false. /// returning true if successful, else returning false.
/// ///
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment