Commit 64e1dbe9 authored by gaurides's avatar gaurides Committed by Scott Cyphers

Use all args for dropout (#3069)

parent 33c74139
...@@ -38,13 +38,13 @@ namespace ngraph ...@@ -38,13 +38,13 @@ namespace ngraph
auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name()); auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto arg4_buffer_index = external_function->get_buffer_index(args[4].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name()); auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name()); auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
size_t element_count = out[0].get_size(); size_t element_count = out[0].get_size();
bool use_seed = drop->get_use_seed(); bool use_seed = drop->get_use_seed();
double keep_prob = drop->get_keep_prob();
// Note: for performance optimization in addition to parallel RNG with multiple, // Note: for performance optimization in addition to parallel RNG with multiple,
// threads, we create, initialize and advance each msr here in builder instead of // threads, we create, initialize and advance each msr here in builder instead of
...@@ -56,7 +56,7 @@ namespace ngraph ...@@ -56,7 +56,7 @@ namespace ngraph
std::vector<std::minstd_rand> vmsr(nthr); std::vector<std::minstd_rand> vmsr(nthr);
if (use_seed) if (use_seed)
{ {
uint32_t seed = drop->get_seed(); uint64_t seed = drop->get_seed();
for (size_t i = 0; i < nthr; i++) for (size_t i = 0; i < nthr; i++)
{ {
std::minstd_rand msr; std::minstd_rand msr;
...@@ -72,13 +72,15 @@ namespace ngraph ...@@ -72,13 +72,15 @@ namespace ngraph
element_count, element_count,
arg_buffer_index, arg_buffer_index,
arg1_buffer_index, arg1_buffer_index,
arg4_buffer_index,
out0_buffer_index, out0_buffer_index,
out1_buffer_index, out1_buffer_index,
keep_prob,
vmsr, vmsr,
use_seed](CPURuntimeContext* ctx, CPUExecutionContext* ectx) { use_seed](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
bool training = static_cast<bool>( bool training = static_cast<bool>(
static_cast<float*>(ctx->buffer_data[arg1_buffer_index])[0]); static_cast<float*>(ctx->buffer_data[arg1_buffer_index])[0]);
double keep_prob =
static_cast<double*>(ctx->buffer_data[arg4_buffer_index])[0];
runtime::cpu::kernel::generate_dropout( runtime::cpu::kernel::generate_dropout(
static_cast<float*>(ctx->buffer_data[arg_buffer_index]), static_cast<float*>(ctx->buffer_data[arg_buffer_index]),
static_cast<float*>(ctx->buffer_data[out0_buffer_index]), static_cast<float*>(ctx->buffer_data[out0_buffer_index]),
...@@ -96,13 +98,15 @@ namespace ngraph ...@@ -96,13 +98,15 @@ namespace ngraph
element_count, element_count,
arg_buffer_index, arg_buffer_index,
arg1_buffer_index, arg1_buffer_index,
arg4_buffer_index,
out0_buffer_index, out0_buffer_index,
out1_buffer_index, out1_buffer_index,
keep_prob,
vmsr, vmsr,
use_seed](CPURuntimeContext* ctx, CPUExecutionContext* ectx) { use_seed](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
bool training = static_cast<bool>( bool training = static_cast<bool>(
static_cast<double*>(ctx->buffer_data[arg1_buffer_index])[0]); static_cast<double*>(ctx->buffer_data[arg1_buffer_index])[0]);
double keep_prob =
static_cast<double*>(ctx->buffer_data[arg4_buffer_index])[0];
runtime::cpu::kernel::generate_dropout( runtime::cpu::kernel::generate_dropout(
static_cast<double*>(ctx->buffer_data[arg_buffer_index]), static_cast<double*>(ctx->buffer_data[arg_buffer_index]),
static_cast<double*>(ctx->buffer_data[out0_buffer_index]), static_cast<double*>(ctx->buffer_data[out0_buffer_index]),
......
...@@ -26,11 +26,9 @@ using namespace ngraph; ...@@ -26,11 +26,9 @@ using namespace ngraph;
op::Dropout::Dropout(const std::shared_ptr<Node>& input, op::Dropout::Dropout(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& gm_const, const std::shared_ptr<Node>& gm_const,
const std::shared_ptr<Node>& use_seed, const std::shared_ptr<Node>& use_seed,
const uint32_t seed, const std::shared_ptr<Node>& seed,
const double keep_prob) const std::shared_ptr<Node>& keep_prob)
: Op("Dropout", check_single_output_args({input, gm_const, use_seed})) : Op("Dropout", check_single_output_args({input, gm_const, use_seed, seed, keep_prob}))
, m_seed(seed)
, m_keep_prob(keep_prob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -41,13 +39,13 @@ op::Dropout::Dropout(const std::shared_ptr<Node>& input, ...@@ -41,13 +39,13 @@ op::Dropout::Dropout(const std::shared_ptr<Node>& input,
shared_ptr<Node> op::Dropout::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Dropout::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 3) if (new_args.size() != 5)
{ {
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
} }
return make_shared<Dropout>( return make_shared<Dropout>(
new_args.at(0), new_args.at(1), new_args.at(2), m_seed, m_keep_prob); new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
} }
bool op::Dropout::get_use_seed() const bool op::Dropout::get_use_seed() const
...@@ -60,3 +58,14 @@ bool op::Dropout::get_use_seed() const ...@@ -60,3 +58,14 @@ bool op::Dropout::get_use_seed() const
} }
return use_seed; return use_seed;
} }
uint64_t op::Dropout::get_seed() const
{
    // The seed is argument 3 of the op. It is only recoverable at graph-build
    // time when that argument is a Constant node; otherwise fall back to 0.
    uint64_t extracted_seed = 0;
    auto seed_const = dynamic_pointer_cast<op::Constant>(get_argument(3));
    if (seed_const)
    {
        // Constant stores raw bytes; reinterpret them as a single uint64_t.
        extracted_seed = *static_cast<const uint64_t*>(seed_const->get_data_ptr());
    }
    return extracted_seed;
}
...@@ -29,20 +29,15 @@ namespace ngraph ...@@ -29,20 +29,15 @@ namespace ngraph
Dropout(const std::shared_ptr<Node>& input, Dropout(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& gm_const, const std::shared_ptr<Node>& gm_const,
const std::shared_ptr<Node>& use_seed, const std::shared_ptr<Node>& use_seed,
const uint32_t seed, const std::shared_ptr<Node>& seed,
const double keep_prob); // keep_prob = 1 - dropout_prob const std::shared_ptr<Node>& keep_prob); // keep_prob = 1 - dropout_prob
bool get_use_seed() const; bool get_use_seed() const;
uint32_t get_seed() const { return m_seed; } uint64_t get_seed() const;
double get_keep_prob() const { return m_keep_prob; } double get_keep_prob() const;
void set_seed(uint32_t new_seed) { m_seed = new_seed; }
void set_keep_prob(double new_keep_prob) { m_keep_prob = new_keep_prob; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
private:
uint32_t m_seed;
double m_keep_prob;
}; };
} }
} }
...@@ -923,8 +923,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout() ...@@ -923,8 +923,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout()
auto x = std::make_shared<pattern::op::Label>(element::f32, shape); auto x = std::make_shared<pattern::op::Label>(element::f32, shape);
auto x_label = std::make_shared<pattern::op::Label>(x, nullptr, NodeVector{x}); auto x_label = std::make_shared<pattern::op::Label>(x, nullptr, NodeVector{x});
uint32_t seed = 1234; uint64_t seed = 1234;
auto seed_label = std::make_shared<pattern::op::Label>(element::u32, Shape{0}); auto seed_label = std::make_shared<pattern::op::Label>(element::u64, Shape{0});
double value = 0.9; double value = 0.9;
auto value_const = ngraph::op::Constant::create(element::f32, Shape{1, 1, 2, 2}, {value}); auto value_const = ngraph::op::Constant::create(element::f32, Shape{1, 1, 2, 2}, {value});
...@@ -960,15 +960,28 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout() ...@@ -960,15 +960,28 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout()
NGRAPH_DEBUG << "training argument to GenerateMask must be constant"; NGRAPH_DEBUG << "training argument to GenerateMask must be constant";
return false; return false;
} }
if (!std::dynamic_pointer_cast<ngraph::op::Constant>(gm->get_argument(2)))
{
NGRAPH_DEBUG << "use_seed argument to GenerateMask must be constant";
return false;
}
if (!std::dynamic_pointer_cast<ngraph::op::Constant>(gm->get_argument(3)))
{
NGRAPH_DEBUG << "seed argument to GenerateMask must be constant";
return false;
}
if (!std::dynamic_pointer_cast<ngraph::op::Constant>(gm->get_argument(4)))
{
NGRAPH_DEBUG << "probability argument to GenerateMask must be constant";
return false;
}
auto gm_value = gm->get_probability(); auto dropout_n = std::make_shared<ngraph::op::Dropout>(pattern_map[x],
auto gm_seed = gm->get_seed(); gm->get_argument(0),
gm->get_argument(2),
auto training = gm->get_argument(0); //for training purpose this is always going to be 1 gm->get_argument(3),
auto use_seed_arg = gm->get_argument(2); // this is the use_seed node gm->get_argument(4));
auto dropout_n = std::make_shared<ngraph::op::Dropout>(
pattern_map[x], training, use_seed_arg, gm_seed, gm_value);
auto goe1 = std::make_shared<ngraph::op::GetOutputElement>(dropout_n, 0); auto goe1 = std::make_shared<ngraph::op::GetOutputElement>(dropout_n, 0);
ngraph::replace_node(m.get_match_root(), goe1); ngraph::replace_node(m.get_match_root(), goe1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment