Commit 1960a44d authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Disable caching when node has state. (#3043)

* Disable caching when node has state.

* Address PR feedback: move has_state to Node.

* Address PR feedback.

* Fix style error.
parent 79587d93
...@@ -216,6 +216,7 @@ namespace ngraph ...@@ -216,6 +216,7 @@ namespace ngraph
virtual bool is_op() const { return false; } virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; } virtual bool is_commutative() { return false; }
virtual bool is_dynamic() const; virtual bool is_dynamic() const;
virtual bool has_state() const { return false; }
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
virtual std::ostream& write_short_description(std::ostream&) const; virtual std::ostream& write_short_description(std::ostream&) const;
......
...@@ -72,6 +72,8 @@ namespace ngraph ...@@ -72,6 +72,8 @@ namespace ngraph
/// \brief Returns the seed value supplied to a random generator /// \brief Returns the seed value supplied to a random generator
uint64_t get_seed() const { return m_seed; } uint64_t get_seed() const { return m_seed; }
bool get_use_seed() const { return m_use_seed; } bool get_use_seed() const { return m_use_seed; }
/// GenerateMask has state.
bool has_state() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override const NodeVector& deltas) override
......
...@@ -924,7 +924,8 @@ using namespace ngraph::runtime; ...@@ -924,7 +924,8 @@ using namespace ngraph::runtime;
// Always enable nodes computing output tensors or nodes whose outputs might get // Always enable nodes computing output tensors or nodes whose outputs might get
// overwritten due to inplace kernels // overwritten due to inplace kernels
// TODO (jbobba) - Do we need to handle cacheability // TODO (jbobba) - Do we need to handle cacheability
if (computes_result(node.get()) || possibly_overwritten(node.get())) if (computes_result(node.get()) || possibly_overwritten(node.get()) ||
node->has_state())
{ {
writer << " || 1"; writer << " || 1";
} }
...@@ -1423,7 +1424,7 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co ...@@ -1423,7 +1424,7 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
bool disable_caching = bool disable_caching =
(reuse_memory && (reuse_memory &&
!cacheable) // Check cacheability only if we are reusing intermediate tensors !cacheable) // Check cacheability only if we are reusing intermediate tensors
|| computes_result(node.get()) || possibly_overwritten(node.get()); || computes_result(node.get()) || possibly_overwritten(node.get()) || node->has_state();
vector<reference_wrapper<bool>> in_stale, out_stale; vector<reference_wrapper<bool>> in_stale, out_stale;
for (const auto& name : in_names) for (const auto& name : in_names)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment