Unverified Commit f0552cc8 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

GenerateMask correction (#3029)

* GenerateMask correction
Add an attribute that controls if the seed should be set on each use
Convert to new virtual method for description implementatin

* Support for switching to dynamic attributes.

* GenerateMask changes in CPU backend (#3042)

* Add CPU builder and kernel for new GenerateMask API

* Remove dead code

* Fix unit-test, PR feedback, file permissions

* Disable new test for non-supporting backends

* Fix CI error

* Codegen support

* Style check

* Fix CI error
parent e51c5824
File mode changed from 100755 to 100644
...@@ -19,17 +19,45 @@ ...@@ -19,17 +19,45 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::GenerateMask::GenerateMask(const std::shared_ptr<Node>& training, const string op::GenerateMask::type_name{"GenerateMask"};
op::GenerateMask::GenerateMask()
: Op()
{
}
#if 0
// Not supported until all transformers use nodes instead of attributes
op::GenerateMask::GenerateMask(const Output<Node>& training,
const Output<Node>& shape,
const Output<Node>& probability,
const Output<Node>& seed,
const Output<Node>& use_seed,
const element::Type& element_type)
: Op({training, shape, probability, seed, use_seed})
, m_element_type(element_type)
{
}
#endif
op::GenerateMask::GenerateMask(const Output<Node>& training,
const Shape& shape, const Shape& shape,
const element::Type& element_type, const element::Type& element_type,
unsigned int seed, uint64_t seed,
double prob) double prob,
: Op("GenerateMask", check_single_output_args({training})) bool use_seed)
, m_shape(shape) : Op({training})
, m_element_type(element_type) , m_element_type(element_type)
, m_shape(shape)
, m_use_seed(use_seed)
, m_seed(seed) , m_seed(seed)
, m_probability(prob) , m_probability(prob)
{ {
set_argument(1, make_shared<op::Constant>(element::u64, Shape{shape.size()}, shape));
set_argument(2,
make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{use_seed}));
set_argument(3, make_shared<op::Constant>(element::u64, Shape{}, std::vector<uint64_t>{seed}));
set_argument(4, make_shared<op::Constant>(element::f64, Shape{}, std::vector<double>{prob}));
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -37,7 +65,7 @@ shared_ptr<Node> op::GenerateMask::copy_with_new_args(const NodeVector& new_args ...@@ -37,7 +65,7 @@ shared_ptr<Node> op::GenerateMask::copy_with_new_args(const NodeVector& new_args
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<GenerateMask>( return make_shared<GenerateMask>(
new_args.at(0), m_shape, m_element_type, m_seed, m_probability); new_args.at(0), m_shape, m_element_type, m_seed, m_probability, m_use_seed);
} }
void ngraph::op::GenerateMask::validate_and_infer_types() void ngraph::op::GenerateMask::validate_and_infer_types()
......
...@@ -29,21 +29,49 @@ namespace ngraph ...@@ -29,21 +29,49 @@ namespace ngraph
class GenerateMask : public op::Op class GenerateMask : public op::Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a GenerateMask node with a given shape, seed, /// \brief Constructs a GenerateMask node with a given shape, seed,
/// probability and training/inference mode /// probability and training/inference mode
GenerateMask(const std::shared_ptr<Node>& training, GenerateMask();
#if 0
/// Switch to dynamic arguments when all transformers have switched to using the node values
/// \brief Constructs a GenerateMask node with a given shape, seed,
/// probability and training/inference mode
GenerateMask(const Output<Node>& training,
const Output<Node>& shape,
const Output<Node>& probability,
const Output<Node>& seed,
const Output<Node>& use_seed,
const element::Type& element_type);
#endif
/// \brief Constructs a GenerateMask node with a given shape, seed,
/// probability and training/inference mode
GenerateMask(const Output<Node>& training,
const Shape& shape, const Shape& shape,
const element::Type& element_type, const element::Type& element_type,
unsigned int seed, uint64_t seed,
double prob); double prob,
bool use_seed = false);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
/// Deprecated accessor for transitional attributes
const Shape& get_mask_shape() const { return m_shape; }
/// \brief Returns the probability of a trial generating 1 (i.e. an element being kept) /// \brief Returns the probability of a trial generating 1 (i.e. an element being kept)
double get_probability() const { return m_probability; } double get_probability() const { return m_probability; }
/// \brief Returns the seed value supplied to a random generator /// \brief Returns the seed value supplied to a random generator
unsigned int get_seed() const { return m_seed; } uint64_t get_seed() const { return m_seed; }
bool get_use_seed() const { return m_use_seed; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override const NodeVector& deltas) override
...@@ -51,10 +79,12 @@ namespace ngraph ...@@ -51,10 +79,12 @@ namespace ngraph
} }
void validate_and_infer_types() override; void validate_and_infer_types() override;
Shape m_shape;
element::Type m_element_type; element::Type m_element_type;
unsigned int m_seed; // These will be deprecated
double m_probability; Shape m_shape;
bool m_use_seed{false};
uint64_t m_seed{0};
double m_probability{0.0};
}; };
} }
} }
...@@ -41,33 +41,93 @@ namespace ngraph ...@@ -41,33 +41,93 @@ namespace ngraph
size_t element_count = out[0].get_size(); size_t element_count = out[0].get_size();
auto arg2_buffer_index =
external_function->get_buffer_index(args[2].get_name()); //use_seed
auto arg3_buffer_index =
external_function->get_buffer_index(args[3].get_name()); //seed
auto arg4_buffer_index =
external_function->get_buffer_index(args[4].get_name()); //prob
auto seed_attr = gm->get_use_seed() ? gm->get_seed() : 0;
auto index = external_function->add_state( auto index = external_function->add_state(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability())); ngraph::RNGState::create_rng_state(seed_attr, gm->get_probability()));
if (args[0].get_element_type() == element::f32) if (args[0].get_element_type() == element::f32)
{ {
functor = [&, index, element_count, arg_buffer_index, out_buffer_index]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { index,
element_count,
arg_buffer_index,
out_buffer_index,
arg2_buffer_index,
arg3_buffer_index,
arg4_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
bool training = static_cast<bool>( bool training = static_cast<bool>(
static_cast<float*>(ctx->buffer_data[arg_buffer_index])[0]); static_cast<float*>(ctx->buffer_data[arg_buffer_index])[0]);
// TODO: get shape when required
bool use_seed = static_cast<bool>(
static_cast<int32_t*>(ctx->buffer_data[arg2_buffer_index])[0]);
uint64_t seed =
static_cast<uint64_t*>(ctx->buffer_data[arg3_buffer_index])[0];
double prob = static_cast<double*>(ctx->buffer_data[arg4_buffer_index])[0];
if (use_seed == false)
{
reference::generate_mask( reference::generate_mask(
static_cast<float*>(ctx->buffer_data[out_buffer_index]), static_cast<float*>(ctx->buffer_data[out_buffer_index]),
element_count, element_count,
static_cast<RNGState*>(ctx->states[index]), static_cast<RNGState*>(ctx->states[index]),
training); training);
}
else
{
reference::generate_mask_no_state(
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
element_count,
training,
seed,
prob);
}
}; };
} }
else if (args[0].get_element_type() == element::f64) else if (args[0].get_element_type() == element::f64)
{ {
functor = [&, index, element_count, arg_buffer_index, out_buffer_index]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { index,
element_count,
arg_buffer_index,
out_buffer_index,
arg2_buffer_index,
arg3_buffer_index,
arg4_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
bool training = static_cast<bool>( bool training = static_cast<bool>(
static_cast<double*>(ctx->buffer_data[arg_buffer_index])[0]); static_cast<double*>(ctx->buffer_data[arg_buffer_index])[0]);
// TODO: get shape when required
bool use_seed = static_cast<bool>(
static_cast<int32_t*>(ctx->buffer_data[arg2_buffer_index])[0]);
uint64_t seed =
static_cast<uint64_t*>(ctx->buffer_data[arg3_buffer_index])[0];
double prob = static_cast<double*>(ctx->buffer_data[arg4_buffer_index])[0];
if (use_seed == false)
{
reference::generate_mask( reference::generate_mask(
static_cast<double*>(ctx->buffer_data[out_buffer_index]), static_cast<double*>(ctx->buffer_data[out_buffer_index]),
element_count, element_count,
static_cast<RNGState*>(ctx->states[index]), static_cast<RNGState*>(ctx->states[index]),
training); training);
}
else
{
reference::generate_mask_no_state(
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
element_count,
training,
seed,
prob);
}
}; };
} }
else else
......
...@@ -3965,10 +3965,25 @@ namespace ngraph ...@@ -3965,10 +3965,25 @@ namespace ngraph
writer << "auto state = static_cast<ngraph::RNGState*>(ctx->states[" << index writer << "auto state = static_cast<ngraph::RNGState*>(ctx->states[" << index
<< "]);\n"; << "]);\n";
writer << "bool training = static_cast<bool>(" << args[0].get_name() << "[0]);\n"; writer << "bool training = static_cast<bool>(" << args[0].get_name() << "[0]);\n";
writer << "reference::generate_mask("; writer << "bool use_seed = static_cast<bool>(" << args[2].get_name() << "[0]);\n";
writer << "uint64_t seed = static_cast<uint64_t>(" << args[3].get_name()
<< "[0]);\n";
writer << "double keep_prob = static_cast<double>(" << args[4].get_name()
<< "[0]);\n";
writer << "if (use_seed == false) \n";
writer << "{\n";
writer << " reference::generate_mask(\n";
writer << " " << out[0].get_name() << ",\n"; writer << " " << out[0].get_name() << ",\n";
writer << " " << out[0].get_size() << ",\n"; writer << " " << out[0].get_size() << ",\n";
writer << " state, training);\n"; writer << " state, training);\n";
writer << "}\n";
writer << "else {\n";
writer << " reference::generate_mask_no_state(\n";
writer << " " << out[0].get_name() << ",\n";
writer << " " << out[0].get_size() << ",\n";
writer << " training, seed, keep_prob);\n";
writer << "}\n";
writer.block_end(); writer.block_end();
} }
......
...@@ -3,6 +3,7 @@ computation_reuse ...@@ -3,6 +3,7 @@ computation_reuse
# int64 is not supprted by cuDNN # int64 is not supprted by cuDNN
dot_matrix_vector_int64 dot_matrix_vector_int64
generate_mask generate_mask
generate_mask2
# custom_mem is not implemented on GPU # custom_mem is not implemented on GPU
tensorview_custom_mem tensorview_custom_mem
# integer is not supported by cuDNN on backward pooling # integer is not supported by cuDNN on backward pooling
......
...@@ -10,6 +10,7 @@ embedding_lookup_10x1_arbitrary_index_type_int ...@@ -10,6 +10,7 @@ embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_10x1_arbitrary_index_type_int64 embedding_lookup_10x1_arbitrary_index_type_int64
embedding_lookup_4x5_reverse embedding_lookup_4x5_reverse
generate_mask generate_mask
generate_mask2
replace_slice_3d replace_slice_3d
replace_slice_3d_strided replace_slice_3d_strided
replace_slice_3d_strided_different_strides replace_slice_3d_strided_different_strides
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <initializer_list> #include <initializer_list>
#include <iostream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
...@@ -357,17 +358,30 @@ private: ...@@ -357,17 +358,30 @@ private:
} }
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
bool use_seed = static_cast<bool>(args[2]->get_data_ptr<const int32_t>()[0]);
if (m_states.count(&node) == 0) if (m_states.count(&node) == 0)
{ {
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node); const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
auto seed = use_seed ? gm->get_seed() : 0;
m_states[&node] = std::unique_ptr<ngraph::RNGState>( m_states[&node] = std::unique_ptr<ngraph::RNGState>(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability())); ngraph::RNGState::create_rng_state(seed, gm->get_probability()));
} }
bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]); bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]);
auto state = m_states.at(&node).get(); auto state = m_states.at(&node).get();
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::generate_mask<T>(out[0]->get_data_ptr<T>(), element_count, state, training); if (!use_seed)
{
reference::generate_mask<T>(
out[0]->get_data_ptr<T>(), element_count, state, training);
}
else
{
uint64_t seed = static_cast<uint64_t>(args[3]->get_data_ptr<const T>()[0]);
double prob = static_cast<double>(args[4]->get_data_ptr<const T>()[0]);
reference::generate_mask_no_state<T>(
out[0]->get_data_ptr<T>(), element_count, training, seed, prob);
}
break; break;
} }
case OP_TYPEID::GetOutputElement: case OP_TYPEID::GetOutputElement:
......
...@@ -57,6 +57,7 @@ max_pool_2d_1channel_1image_overpadded ...@@ -57,6 +57,7 @@ max_pool_2d_1channel_1image_overpadded
max_pool_3d max_pool_3d
maxpool_bprop_larger_than_cache maxpool_bprop_larger_than_cache
generate_mask generate_mask
generate_mask2
avg_pool_3d avg_pool_3d
avg_pool_3d_uneven_strided_padded_include_in_computation avg_pool_3d_uneven_strided_padded_include_in_computation
quantize_dynamic_offset # Quantization/Dequantization is unimplemented quantize_dynamic_offset # Quantization/Dequantization is unimplemented
......
...@@ -37,6 +37,19 @@ namespace ngraph ...@@ -37,6 +37,19 @@ namespace ngraph
out[i] = training ? static_cast<T>(bd(gen)) : static_cast<T>(1); out[i] = training ? static_cast<T>(bd(gen)) : static_cast<T>(1);
} }
} }
template <typename T>
void generate_mask_no_state(
T* out, size_t count, bool training, uint32_t seed, double prob)
{
std::mt19937 gen(seed);
std::bernoulli_distribution bd(prob);
for (size_t i = 0; i < count; i++)
{
out[i] = training ? static_cast<T>(bd(gen)) : static_cast<T>(1);
}
}
} }
} }
} }
...@@ -1006,9 +1006,15 @@ static shared_ptr<ngraph::Function> ...@@ -1006,9 +1006,15 @@ static shared_ptr<ngraph::Function>
auto type = read_element_type(node_js.at("type")); auto type = read_element_type(node_js.at("type"));
auto seed = node_js.at("seed").get<unsigned int>(); auto seed = node_js.at("seed").get<unsigned int>();
auto probability = node_js.at("probability").get<double>(); auto probability = node_js.at("probability").get<double>();
auto use_seed_maybe = node_js.at("use_seed");
bool use_seed = false;
if (!use_seed_maybe.empty())
{
use_seed = use_seed_maybe.get<bool>();
}
node = node = make_shared<op::GenerateMask>(
make_shared<op::GenerateMask>(args[0], output_shape, type, seed, probability); args[0], output_shape, type, seed, probability, use_seed);
break; break;
} }
case OP_TYPEID::GetOutputElement: case OP_TYPEID::GetOutputElement:
...@@ -2024,8 +2030,9 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2024,8 +2030,9 @@ static json write(const Node& n, bool binary_constant_data)
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
auto tmp = dynamic_cast<const op::GenerateMask*>(&n); auto tmp = dynamic_cast<const op::GenerateMask*>(&n);
node["output_shape"] = tmp->get_shape(); node["output_shape"] = tmp->get_mask_shape();
node["type"] = write_element_type(tmp->get_element_type()); node["type"] = write_element_type(tmp->get_element_type());
node["use_seed"] = tmp->get_use_seed();
node["seed"] = tmp->get_seed(); node["seed"] = tmp->get_seed();
node["probability"] = tmp->get_probability(); node["probability"] = tmp->get_probability();
break; break;
......
...@@ -6287,8 +6287,10 @@ NGRAPH_TEST(${BACKEND_NAME}, generate_mask) ...@@ -6287,8 +6287,10 @@ NGRAPH_TEST(${BACKEND_NAME}, generate_mask)
Shape result_shape{1, 128}; Shape result_shape{1, 128};
const unsigned int seed = 777; const unsigned int seed = 777;
auto training = op::Constant::create(element::f32, Shape{}, {1}); auto training = op::Constant::create(element::f32, Shape{}, {1});
auto gen_mask = make_shared<op::GenerateMask>(training, result_shape, element::f32, seed, 0.5); auto gen_mask =
auto gen_mask2 = make_shared<op::GenerateMask>(training, result_shape, element::f32, seed, 0.5); make_shared<op::GenerateMask>(training, result_shape, element::f32, seed, 0.5, false);
auto gen_mask2 =
make_shared<op::GenerateMask>(training, result_shape, element::f32, seed, 0.5, false);
auto f = make_shared<Function>(NodeVector{gen_mask, gen_mask2}, ParameterVector{}); auto f = make_shared<Function>(NodeVector{gen_mask, gen_mask2}, ParameterVector{});
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
...@@ -6312,6 +6314,42 @@ NGRAPH_TEST(${BACKEND_NAME}, generate_mask) ...@@ -6312,6 +6314,42 @@ NGRAPH_TEST(${BACKEND_NAME}, generate_mask)
ASSERT_FALSE(std::any_of(result2_2.begin(), result2_2.end(), is_not_zero_or_one)); ASSERT_FALSE(std::any_of(result2_2.begin(), result2_2.end(), is_not_zero_or_one));
} }
NGRAPH_TEST(${BACKEND_NAME}, generate_mask2)
{
Shape scalar{};
Shape result_shape{1, 128};
const unsigned int seed = 777;
auto training = op::Constant::create(element::f32, Shape{}, {1});
auto gen_mask =
make_shared<op::GenerateMask>(training, result_shape, element::f32, seed, 0.5, true);
auto gen_mask2 =
make_shared<op::GenerateMask>(training, result_shape, element::f32, seed, 0.5, true);
auto f = make_shared<Function>(NodeVector{gen_mask, gen_mask2}, ParameterVector{});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto is_not_zero_or_one = [](float num) { return num != 0.f && num != 1.f; };
auto result_tv1 = backend->create_tensor<float>(result_shape);
auto result_tv2 = backend->create_tensor<float>(result_shape);
auto handle = backend->compile(f);
handle->call_with_validate({result_tv1, result_tv2}, {});
auto result1 = read_vector<float>(result_tv1);
auto result2 = read_vector<float>(result_tv2);
ASSERT_TRUE(test::all_close_f(result1, result2));
ASSERT_FALSE(std::any_of(result1.begin(), result1.end(), is_not_zero_or_one));
auto result_tv1_2 = backend->create_tensor<float>(result_shape);
auto result_tv2_2 = backend->create_tensor<float>(result_shape);
handle->call_with_validate({result_tv1_2, result_tv2_2}, {});
auto result1_2 = read_vector<float>(result_tv1_2);
auto result2_2 = read_vector<float>(result_tv2_2);
ASSERT_TRUE(test::all_close_f(result1, result1_2));
ASSERT_FALSE(std::any_of(result1_2.begin(), result1_2.end(), is_not_zero_or_one));
ASSERT_TRUE(test::all_close_f(result2, result2_2));
ASSERT_FALSE(std::any_of(result2_2.begin(), result2_2.end(), is_not_zero_or_one));
}
NGRAPH_TEST(${BACKEND_NAME}, quantize) NGRAPH_TEST(${BACKEND_NAME}, quantize)
{ {
Shape input_shape{4, 3}; Shape input_shape{4, 3};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment