Commit 64e1dbe9 authored by gaurides's avatar gaurides Committed by Scott Cyphers

Use all args for dropout (#3069)

parent 33c74139
......@@ -38,13 +38,13 @@ namespace ngraph
auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto arg4_buffer_index = external_function->get_buffer_index(args[4].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
size_t element_count = out[0].get_size();
bool use_seed = drop->get_use_seed();
double keep_prob = drop->get_keep_prob();
// Note: for performance optimization in addition to parallel RNG with multiple,
// threads, we create, initialize and advance each msr here in builder instead of
......@@ -56,7 +56,7 @@ namespace ngraph
std::vector<std::minstd_rand> vmsr(nthr);
if (use_seed)
{
uint32_t seed = drop->get_seed();
uint64_t seed = drop->get_seed();
for (size_t i = 0; i < nthr; i++)
{
std::minstd_rand msr;
......@@ -72,13 +72,15 @@ namespace ngraph
element_count,
arg_buffer_index,
arg1_buffer_index,
arg4_buffer_index,
out0_buffer_index,
out1_buffer_index,
keep_prob,
vmsr,
use_seed](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
bool training = static_cast<bool>(
static_cast<float*>(ctx->buffer_data[arg1_buffer_index])[0]);
double keep_prob =
static_cast<double*>(ctx->buffer_data[arg4_buffer_index])[0];
runtime::cpu::kernel::generate_dropout(
static_cast<float*>(ctx->buffer_data[arg_buffer_index]),
static_cast<float*>(ctx->buffer_data[out0_buffer_index]),
......@@ -96,13 +98,15 @@ namespace ngraph
element_count,
arg_buffer_index,
arg1_buffer_index,
arg4_buffer_index,
out0_buffer_index,
out1_buffer_index,
keep_prob,
vmsr,
use_seed](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
bool training = static_cast<bool>(
static_cast<double*>(ctx->buffer_data[arg1_buffer_index])[0]);
double keep_prob =
static_cast<double*>(ctx->buffer_data[arg4_buffer_index])[0];
runtime::cpu::kernel::generate_dropout(
static_cast<double*>(ctx->buffer_data[arg_buffer_index]),
static_cast<double*>(ctx->buffer_data[out0_buffer_index]),
......
......@@ -26,11 +26,9 @@ using namespace ngraph;
op::Dropout::Dropout(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& gm_const,
const std::shared_ptr<Node>& use_seed,
const uint32_t seed,
const double keep_prob)
: Op("Dropout", check_single_output_args({input, gm_const, use_seed}))
, m_seed(seed)
, m_keep_prob(keep_prob)
const std::shared_ptr<Node>& seed,
const std::shared_ptr<Node>& keep_prob)
: Op("Dropout", check_single_output_args({input, gm_const, use_seed, seed, keep_prob}))
{
constructor_validate_and_infer_types();
......@@ -41,13 +39,13 @@ op::Dropout::Dropout(const std::shared_ptr<Node>& input,
shared_ptr<Node> op::Dropout::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
if (new_args.size() != 5)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Dropout>(
new_args.at(0), new_args.at(1), new_args.at(2), m_seed, m_keep_prob);
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
}
bool op::Dropout::get_use_seed() const
......@@ -60,3 +58,14 @@ bool op::Dropout::get_use_seed() const
}
return use_seed;
}
uint64_t op::Dropout::get_seed() const
{
uint64_t seed = 0;
if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(3)))
{
auto seed_ptr = static_cast<const uint64_t*>(const_op->get_data_ptr());
seed = *seed_ptr;
}
return seed;
}
......@@ -29,20 +29,15 @@ namespace ngraph
Dropout(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& gm_const,
const std::shared_ptr<Node>& use_seed,
const uint32_t seed,
const double keep_prob); // keep_prob = 1 - dropout_prob
const std::shared_ptr<Node>& seed,
const std::shared_ptr<Node>& keep_prob); // keep_prob = 1 - dropout_prob
bool get_use_seed() const;
uint32_t get_seed() const { return m_seed; }
double get_keep_prob() const { return m_keep_prob; }
void set_seed(uint32_t new_seed) { m_seed = new_seed; }
void set_keep_prob(double new_keep_prob) { m_keep_prob = new_keep_prob; }
uint64_t get_seed() const;
double get_keep_prob() const;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
uint32_t m_seed;
double m_keep_prob;
};
}
}
......@@ -923,8 +923,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout()
auto x = std::make_shared<pattern::op::Label>(element::f32, shape);
auto x_label = std::make_shared<pattern::op::Label>(x, nullptr, NodeVector{x});
uint32_t seed = 1234;
auto seed_label = std::make_shared<pattern::op::Label>(element::u32, Shape{0});
uint64_t seed = 1234;
auto seed_label = std::make_shared<pattern::op::Label>(element::u64, Shape{0});
double value = 0.9;
auto value_const = ngraph::op::Constant::create(element::f32, Shape{1, 1, 2, 2}, {value});
......@@ -960,15 +960,28 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout()
NGRAPH_DEBUG << "training argument to GenerateMask must be constant";
return false;
}
if (!std::dynamic_pointer_cast<ngraph::op::Constant>(gm->get_argument(2)))
{
NGRAPH_DEBUG << "use_seed argument to GenerateMask must be constant";
return false;
}
if (!std::dynamic_pointer_cast<ngraph::op::Constant>(gm->get_argument(3)))
{
NGRAPH_DEBUG << "seed argument to GenerateMask must be constant";
return false;
}
if (!std::dynamic_pointer_cast<ngraph::op::Constant>(gm->get_argument(4)))
{
NGRAPH_DEBUG << "probability argument to GenerateMask must be constant";
return false;
}
auto gm_value = gm->get_probability();
auto gm_seed = gm->get_seed();
auto training = gm->get_argument(0); //for training purpose this is always going to be 1
auto use_seed_arg = gm->get_argument(2); // this is the use_seed node
auto dropout_n = std::make_shared<ngraph::op::Dropout>(pattern_map[x],
gm->get_argument(0),
gm->get_argument(2),
gm->get_argument(3),
gm->get_argument(4));
auto dropout_n = std::make_shared<ngraph::op::Dropout>(
pattern_map[x], training, use_seed_arg, gm_seed, gm_value);
auto goe1 = std::make_shared<ngraph::op::GetOutputElement>(dropout_n, 0);
ngraph::replace_node(m.get_match_root(), goe1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment