Commit 0bb9368a authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

DropOut for INT (#2029)

parent ed14b94f
......@@ -128,6 +128,7 @@
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/state/rng_state.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/runtime/reference/allreduce.hpp"
......@@ -175,6 +176,7 @@ private:
bool m_performance_counters_enabled = false;
std::unordered_map<const Node*, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::unique_ptr<RNGState>> m_states;
std::unique_ptr<AlignedBuffer> m_temporary_memory;
void* get_temporary_pointer(size_t offset) { return m_temporary_memory->get_ptr(offset); }
......@@ -332,8 +334,19 @@ private:
}
case OP_TYPEID::GenerateMask:
{
throw ngraph_error(
"GenerateMask is an experimental op that's only supported on CPU backend");
if (instance.m_states.count(&node) == 0)
{
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
instance.m_states[&node] = std::unique_ptr<ngraph::RNGState>(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
}
bool training = static_cast<bool>(static_cast<const T*>(args[0])[0]);
auto state = instance.m_states.at(&node).get();
size_t element_count = shape_size(node.get_output_shape(0));
reference::generate_mask<T>(
reinterpret_cast<T*>(out[0]), element_count, state, training);
break;
}
case OP_TYPEID::GetOutputElement:
{
......
......@@ -7,7 +7,6 @@ batchnorm_fprop_inference_b2c2h2w1
batchnorm_fprop_bprop
batchnorm_fprop_bprop_2step
computation_reuse
generate_mask
topk_int64
topk_3d_large_input_max
topk_3d_large_input_min
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment