Commit ba640dbb authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Misc fixes for partial shapes (#1987)

parent 37dc586c
...@@ -130,7 +130,7 @@ void op::AvgPoolBackprop::validate_and_infer_types() ...@@ -130,7 +130,7 @@ void op::AvgPoolBackprop::validate_and_infer_types()
m_window_movement_strides, m_window_movement_strides,
m_include_padding_in_avg_computation); m_include_padding_in_avg_computation);
const PartialShape& delta_shape = get_input_shape(0); const PartialShape& delta_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape)) NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape does not match delta shape (inferred forward output " << "Inferred forward output shape does not match delta shape (inferred forward output "
......
...@@ -30,7 +30,7 @@ op::Convert::Convert(const shared_ptr<Node>& arg, const element::Type& element_t ...@@ -30,7 +30,7 @@ op::Convert::Convert(const shared_ptr<Node>& arg, const element::Type& element_t
void op::Convert::validate_and_infer_types() void op::Convert::validate_and_infer_types()
{ {
set_output_type(0, m_element_type, get_input_shape(0)); set_output_type(0, m_element_type, get_input_partial_shape(0));
} }
shared_ptr<Node> op::Convert::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Convert::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -476,7 +476,7 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types() ...@@ -476,7 +476,7 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
const PartialShape& data_batch_shape = get_input_partial_shape(0); const PartialShape& data_batch_shape = get_input_partial_shape(0);
element::Type data_batch_et = get_input_element_type(0); element::Type data_batch_et = get_input_element_type(0);
const PartialShape& delta_shape = get_input_shape(1); const PartialShape& delta_shape = get_input_partial_shape(1);
element::Type delta_et = get_input_element_type(1); element::Type delta_et = get_input_element_type(1);
element::Type forward_result_et; element::Type forward_result_et;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment