Commit fa808c91 authored by Evgenya Stepyreva's avatar Evgenya Stepyreva Committed by Scott Cyphers

[ TI ] TensorIterator m_num_iterations fix (#4070)

* [ TI ] TensorIterator m_num_iterations fix

* style
parent 0d0bd8de
...@@ -468,6 +468,7 @@ std::shared_ptr<Node> op::TensorIterator::copy_with_new_args(const NodeVector& n ...@@ -468,6 +468,7 @@ std::shared_ptr<Node> op::TensorIterator::copy_with_new_args(const NodeVector& n
} }
} }
op->m_num_iterations = m_num_iterations;
auto func = std::make_shared<Function>(m_body->get_results(), m_body->get_parameters()); auto func = std::make_shared<Function>(m_body->get_results(), m_body->get_parameters());
auto spec_func = specialize_function( auto spec_func = specialize_function(
func, types, new_shapes, std::vector<void*>(new_args.size(), nullptr), false, true); func, types, new_shapes, std::vector<void*>(new_args.size(), nullptr), false, true);
......
...@@ -300,6 +300,11 @@ namespace ngraph ...@@ -300,6 +300,11 @@ namespace ngraph
void revalidate_and_infer_types_for_body_ops(); void revalidate_and_infer_types_for_body_ops();
int64_t get_num_iterations() const { return m_num_iterations; } int64_t get_num_iterations() const { return m_num_iterations; }
void set_num_iterations(int64_t num_iterations)
{
m_num_iterations = num_iterations;
}
private: private:
// Find an input corresponding to value, adding one if necessary. // Find an input corresponding to value, adding one if necessary.
Input<Node> input_for_value(const Output<Node>& value); Input<Node> input_for_value(const Output<Node>& value);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment