Commit fa808c91 authored by Evgenya Stepyreva's avatar Evgenya Stepyreva Committed by Scott Cyphers

[ TI ] TensorIterator m_num_iterations fix (#4070)

* [ TI ] TensorIterator m_num_iterations fix

* style
parent 0d0bd8de
......@@ -468,6 +468,7 @@ std::shared_ptr<Node> op::TensorIterator::copy_with_new_args(const NodeVector& n
}
}
op->m_num_iterations = m_num_iterations;
auto func = std::make_shared<Function>(m_body->get_results(), m_body->get_parameters());
auto spec_func = specialize_function(
func, types, new_shapes, std::vector<void*>(new_args.size(), nullptr), false, true);
......
......@@ -300,6 +300,11 @@ namespace ngraph
void revalidate_and_infer_types_for_body_ops();
int64_t get_num_iterations() const { return m_num_iterations; }
void set_num_iterations(int64_t num_iterations)
{
m_num_iterations = num_iterations;
}
private:
// Find an input corresponding to value, adding one if necessary.
Input<Node> input_for_value(const Output<Node>& value);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment