Unverified Commit db2de1b3 authored by Fenglei Tian's avatar Fenglei Tian Committed by GitHub

Merge branch 'master' into tfl/send_recv_op

parents c1680ce3 1960a44d
......@@ -216,6 +216,7 @@ namespace ngraph
virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; }
virtual bool is_dynamic() const;
virtual bool has_state() const { return false; }
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
virtual std::ostream& write_short_description(std::ostream&) const;
......
......@@ -72,6 +72,8 @@ namespace ngraph
/// \brief Returns the seed value supplied to a random generator
uint64_t get_seed() const { return m_seed; }
bool get_use_seed() const { return m_use_seed; }
/// GenerateMask has state.
bool has_state() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override
......
......@@ -924,7 +924,8 @@ using namespace ngraph::runtime;
// Always enable nodes computing output tensors or nodes whose outputs might get
// overwritten due to inplace kernels
// TODO (jbobba) - Do we need to handle cacheability
if (computes_result(node.get()) || possibly_overwritten(node.get()))
if (computes_result(node.get()) || possibly_overwritten(node.get()) ||
node->has_state())
{
writer << " || 1";
}
......@@ -1423,7 +1424,7 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
bool disable_caching =
(reuse_memory &&
!cacheable) // Check cacheability only if we are reusing intermediate tensors
|| computes_result(node.get()) || possibly_overwritten(node.get());
|| computes_result(node.get()) || possibly_overwritten(node.get()) || node->has_state();
vector<reference_wrapper<bool>> in_stale, out_stale;
for (const auto& name : in_names)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment