Commit 167a6530 authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Support for v1 GenerateMask op ( dynamic shape ) (#3779)

* Dynshape support for GenerateMask

* fix clang error

* Remove comments

* Test case correction

* Disable plaidml

* Merge
parent c5998d23
...@@ -20,28 +20,14 @@ ...@@ -20,28 +20,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::GenerateMask::type_info; constexpr NodeTypeInfo op::v0::GenerateMask::type_info;
#if 0 op::v0::GenerateMask::GenerateMask(const Output<Node>& training,
// Not supported until all transformers use nodes instead of attributes const Shape& shape,
op::GenerateMask::GenerateMask(const Output<Node>& training, const element::Type& element_type,
const Output<Node>& shape, uint64_t seed,
const Output<Node>& probability, double prob,
const Output<Node>& seed, bool use_seed)
const Output<Node>& use_seed,
const element::Type& element_type)
: Op({training, shape, probability, seed, use_seed})
, m_element_type(element_type)
{
}
#endif
op::GenerateMask::GenerateMask(const Output<Node>& training,
const Shape& shape,
const element::Type& element_type,
uint64_t seed,
double prob,
bool use_seed)
: Op({training}) : Op({training})
, m_element_type(element_type) , m_element_type(element_type)
, m_shape(shape) , m_shape(shape)
...@@ -61,14 +47,14 @@ op::GenerateMask::GenerateMask(const Output<Node>& training, ...@@ -61,14 +47,14 @@ op::GenerateMask::GenerateMask(const Output<Node>& training,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::GenerateMask::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::GenerateMask::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<GenerateMask>( return make_shared<v0::GenerateMask>(
new_args.at(0), m_shape, m_element_type, m_seed, m_probability, m_use_seed); new_args.at(0), m_shape, m_element_type, m_seed, m_probability, m_use_seed);
} }
void ngraph::op::GenerateMask::validate_and_infer_types() void ngraph::op::v0::GenerateMask::validate_and_infer_types()
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
get_input_partial_shape(0).compatible(PartialShape{}), get_input_partial_shape(0).compatible(PartialShape{}),
...@@ -79,3 +65,65 @@ void ngraph::op::GenerateMask::validate_and_infer_types() ...@@ -79,3 +65,65 @@ void ngraph::op::GenerateMask::validate_and_infer_types()
set_output_type(0, m_element_type, m_shape); set_output_type(0, m_element_type, m_shape);
} }
// V1 version starts
constexpr NodeTypeInfo op::v1::GenerateMask::type_info;
op::v1::GenerateMask::GenerateMask(const Output<Node>& training,
const Output<Node>& shape,
const element::Type& element_type,
uint64_t seed,
double prob,
bool use_seed)
: Op({training, shape})
, m_element_type(element_type)
, m_use_seed(use_seed)
, m_seed(seed)
, m_probability(prob)
{
set_argument(2,
make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{use_seed}));
set_argument(3, make_shared<op::Constant>(element::u64, Shape{}, std::vector<uint64_t>{seed}));
set_argument(4, make_shared<op::Constant>(element::f64, Shape{}, std::vector<double>{prob}));
add_provenance_group_member(input_value(2).get_node_shared_ptr());
add_provenance_group_member(input_value(3).get_node_shared_ptr());
add_provenance_group_member(input_value(4).get_node_shared_ptr());
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::GenerateMask::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::GenerateMask>(
new_args.at(0), new_args.at(1), m_element_type, m_seed, m_probability, m_use_seed);
}
const Shape op::v1::GenerateMask::get_mask_shape() const
{
Shape shape;
if (auto const_op = as_type<op::Constant>(input_value(1).get_node()))
{
shape = const_op->get_shape_val();
}
return shape;
}
void ngraph::op::v1::GenerateMask::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(0).compatible(PartialShape{}),
"Training node should be a scalar flag indicating a mode");
NODE_VALIDATION_CHECK(
this, m_element_type.is_static(), "Output element type must not be dynamic.");
PartialShape mask_shape{PartialShape::dynamic()};
if (input_value(1).get_node_shared_ptr()->is_constant())
{
mask_shape = get_mask_shape();
}
set_input_is_relevant_to_shape(1);
set_output_type(0, m_element_type, mask_shape);
}
...@@ -24,70 +24,124 @@ namespace ngraph ...@@ -24,70 +24,124 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief GenerateMask namespace v0
///
class GenerateMask : public op::Op
{ {
public: /// \brief GenerateMask
NGRAPH_API ///
static constexpr NodeTypeInfo type_info{"GenerateMask", 0}; class GenerateMask : public op::Op
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a GenerateMask node with a given shape, seed,
/// probability and training/inference mode
GenerateMask() = default;
#if 0
/// Switch to dynamic arguments when all transformers have switched to using the node values
/// \brief Constructs a GenerateMask node with a given shape, seed,
/// probability and training/inference mode
GenerateMask(const Output<Node>& training,
const Output<Node>& shape,
const Output<Node>& probability,
const Output<Node>& seed,
const Output<Node>& use_seed,
const element::Type& element_type);
#endif
/// \brief Constructs a GenerateMask node with a given shape, seed,
/// probability and training/inference mode
GenerateMask(const Output<Node>& training,
const Shape& shape,
const element::Type& element_type,
uint64_t seed,
double prob,
bool use_seed = false);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type)
{ {
m_element_type = element_type; public:
} NGRAPH_API
static constexpr NodeTypeInfo type_info{"GenerateMask", 0};
/// Deprecated accessor for transitional attributes const NodeTypeInfo& get_type_info() const override { return type_info; }
const Shape& get_mask_shape() const { return m_shape; } /// \brief Constructs a GenerateMask node with a given shape, seed,
/// \brief Returns the probability of a trial generating 1 (i.e. an element being kept) /// probability and training/inference mode
double get_probability() const { return m_probability; } GenerateMask() = default;
/// \brief Returns the seed value supplied to a random generator
uint64_t get_seed() const { return m_seed; } /// \brief Constructs a GenerateMask node with a given shape, seed,
bool get_use_seed() const { return m_use_seed; } /// probability and training/inference mode
/// GenerateMask has state. GenerateMask(const Output<Node>& training,
bool has_state() const override { return true; } const Shape& shape,
void validate_and_infer_types() override; const element::Type& element_type,
uint64_t seed,
protected: double prob,
virtual void generate_adjoints(autodiff::Adjoints& /* adjoints */, bool use_seed = false);
const NodeVector& /* deltas */) override
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
/// Deprecated accessor for transitional attributes
const Shape& get_mask_shape() const { return m_shape; }
/// \brief Returns the probability of a trial generating 1 (i.e. an element being
/// kept)
double get_probability() const { return m_probability; }
/// \brief Returns the seed value supplied to a random generator
uint64_t get_seed() const { return m_seed; }
bool get_use_seed() const { return m_use_seed; }
/// GenerateMask has state.
bool has_state() const override { return true; }
void validate_and_infer_types() override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& /* adjoints */,
const NodeVector& /* deltas */) override
{
}
element::Type m_element_type;
// These will be deprecated
Shape m_shape;
bool m_use_seed{false};
uint64_t m_seed{0};
double m_probability{0.0};
};
} // namespace v0
namespace v1
{
/// \brief GenerateMask
///
class GenerateMask : public op::Op
{ {
} public:
NGRAPH_API
element::Type m_element_type; static constexpr NodeTypeInfo type_info{"GenerateMask", 1};
// These will be deprecated const NodeTypeInfo& get_type_info() const override { return type_info; }
Shape m_shape; /// \brief Constructs a GenerateMask node with a given shape, seed,
bool m_use_seed{false}; /// probability and training/inference mode
uint64_t m_seed{0}; GenerateMask() = default;
double m_probability{0.0};
}; /// \brief Constructs a GenerateMask node with a given shape, seed,
} /// probability and training/inference mode
} GenerateMask(const Output<Node>& training,
const Output<Node>& shape,
const element::Type& element_type,
uint64_t seed,
double prob,
bool use_seed = false);
size_t get_version() const override { return 1; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
/// Deprecated accessor for transitional attributes
const Shape get_mask_shape() const;
/// \brief Returns the probability of a trial generating 1 (i.e. an element being
/// kept)
double get_probability() const { return m_probability; }
/// \brief Returns the seed value supplied to a random generator
uint64_t get_seed() const { return m_seed; }
bool get_use_seed() const { return m_use_seed; }
/// GenerateMask has state.
bool has_state() const override { return true; }
void validate_and_infer_types() override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& /* adjoints */,
const NodeVector& /* deltas */) override
{
}
element::Type m_element_type;
// These will be deprecated
bool m_use_seed{false};
uint64_t m_seed{0};
double m_probability{0.0};
};
} // namespace v1
using v0::GenerateMask;
} // op
} // ngraph
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp" #include "ngraph/op/max_pool.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
...@@ -402,6 +403,27 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -402,6 +403,27 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::GenerateMask:
{
auto tmp = dynamic_cast<const op::v1::GenerateMask*>(node.get());
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
auto mask_shape =
static_pointer_cast<op::Constant>(node->input_value(1).get_node_shared_ptr())
->get_shape_val();
auto seed = tmp->get_seed();
auto use_seed = tmp->get_use_seed();
auto probability = tmp->get_probability();
auto et = tmp->get_element_type();
auto replacement_node = make_shared<op::v0::GenerateMask>(
node->input(0).get_source_output(), mask_shape, et, seed, probability, use_seed);
replace_node(node, replacement_node);
modified = true;
break;
}
default: break; default: break;
} }
#if defined(__clang__) #if defined(__clang__)
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/max.hpp" #include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp" #include "ngraph/op/max_pool.hpp"
...@@ -144,7 +145,6 @@ namespace ngraph ...@@ -144,7 +145,6 @@ namespace ngraph
class Or; class Or;
class Xor; class Xor;
class CompiledKernel; class CompiledKernel;
class GenerateMask;
class Dropout; class Dropout;
class Dequantize; class Dequantize;
class Quantize; class Quantize;
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/op/experimental/dyn_replace_slice.hpp" #include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp" #include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/range.hpp" #include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/transpose.hpp" #include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
...@@ -89,7 +90,8 @@ bool is_dynamic_op(const std::shared_ptr<Node>& op) ...@@ -89,7 +90,8 @@ bool is_dynamic_op(const std::shared_ptr<Node>& op)
return is_type<op::Transpose>(op) || is_type<op::DynBroadcast>(op) || return is_type<op::Transpose>(op) || is_type<op::DynBroadcast>(op) ||
is_type<op::DynReplaceSlice>(op) || is_type<op::DynSlice>(op) || is_type<op::DynReplaceSlice>(op) || is_type<op::DynSlice>(op) ||
is_type<op::v1::Reshape>(op) || is_type<op::DynReshape>(op) || is_type<op::Range>(op) || is_type<op::v1::Reshape>(op) || is_type<op::DynReshape>(op) || is_type<op::Range>(op) ||
is_type<op::v1::AvgPoolBackprop>(op) || is_type<op::v1::Broadcast>(op); is_type<op::v1::GenerateMask>(op) || is_type<op::v1::AvgPoolBackprop>(op) ||
is_type<op::v1::Broadcast>(op);
} }
// Helper for a vile hack in DynamicExecutable::call. See body of that function for details. // Helper for a vile hack in DynamicExecutable::call. See body of that function for details.
......
...@@ -308,3 +308,6 @@ convert_bf16_float32 ...@@ -308,3 +308,6 @@ convert_bf16_float32
# infinitive values are returned for below cases # infinitive values are returned for below cases
normalize_across_c_2x2_shape normalize_across_c_2x2_shape
normalize_across_c_2x4_shape normalize_across_c_2x4_shape
# dyn shape
dyn_generate_mask
...@@ -1354,14 +1354,24 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1354,14 +1354,24 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
auto type = read_element_type(node_js.at("type")); auto type = read_element_type(node_js.at("type"));
auto seed = node_js.at("seed").get<unsigned int>(); auto seed = node_js.at("seed").get<unsigned int>();
auto probability = node_js.at("probability").get<double>(); auto probability = node_js.at("probability").get<double>();
bool use_seed = get_or_default<bool>(node_js, "use_seed", false); bool use_seed = get_or_default<bool>(node_js, "use_seed", false);
node = make_shared<op::GenerateMask>( if (op_version == 0)
args[0], output_shape, type, seed, probability, use_seed); {
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::v0::GenerateMask>(
args[0], output_shape, type, seed, probability, use_seed);
}
if (op_version == 1)
{
node = make_shared<op::v1::GenerateMask>(
args[0], args[1], type, seed, probability, use_seed);
}
break; break;
} }
case OP_TYPEID::GetOutputElement: case OP_TYPEID::GetOutputElement:
...@@ -2868,11 +2878,15 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -2868,11 +2878,15 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
auto tmp = static_cast<const op::GenerateMask*>(&n); auto tmp = static_cast<const op::GenerateMask*>(&n);
node["output_shape"] = tmp->get_mask_shape();
node["type"] = write_element_type(tmp->get_element_type()); node["type"] = write_element_type(tmp->get_element_type());
node["use_seed"] = tmp->get_use_seed(); node["use_seed"] = tmp->get_use_seed();
node["seed"] = tmp->get_seed(); node["seed"] = tmp->get_seed();
node["probability"] = tmp->get_probability(); node["probability"] = tmp->get_probability();
if (op_version == 0)
{
node["output_shape"] = tmp->get_mask_shape();
}
break; break;
} }
case OP_TYPEID::Greater: case OP_TYPEID::Greater:
......
...@@ -71,6 +71,7 @@ set(SRC ...@@ -71,6 +71,7 @@ set(SRC
op.cpp op.cpp
opset_pass/broadcast_opset_pass.cpp opset_pass/broadcast_opset_pass.cpp
opset_pass/convolution_opset_pass.cpp opset_pass/convolution_opset_pass.cpp
opset_pass/generate_mask_opset_pass.cpp
opset_pass/gather_opset_pass.cpp opset_pass/gather_opset_pass.cpp
opset_pass/pad_opset_pass.cpp opset_pass/pad_opset_pass.cpp
opset_pass/poolings_opset_pass.cpp opset_pass/poolings_opset_pass.cpp
......
...@@ -96,3 +96,44 @@ NGRAPH_TEST(${BACKEND_NAME}, generate_mask2) ...@@ -96,3 +96,44 @@ NGRAPH_TEST(${BACKEND_NAME}, generate_mask2)
ASSERT_TRUE(test::all_close_f(result2, result2_2)); ASSERT_TRUE(test::all_close_f(result2, result2_2));
ASSERT_FALSE(std::any_of(result2_2.begin(), result2_2.end(), is_not_zero_or_one)); ASSERT_FALSE(std::any_of(result2_2.begin(), result2_2.end(), is_not_zero_or_one));
} }
NGRAPH_TEST(${BACKEND_NAME}, dyn_generate_mask)
{
const unsigned int seed = 777;
auto training = op::Constant::create(element::f32, Shape{}, {1});
auto result_shape =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto gen_mask =
make_shared<op::v1::GenerateMask>(training, result_shape, element::f32, seed, 0.5, true);
auto gen_mask2 =
make_shared<op::v1::GenerateMask>(training, result_shape, element::f32, seed, 0.5, true);
auto f = make_shared<Function>(NodeVector{gen_mask, gen_mask2}, ParameterVector{result_shape});
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto is_not_zero_or_one = [](float num) { return num != 0.f && num != 1.f; };
vector<int64_t> shapes = {1, 128};
auto shape_result = backend->create_tensor(element::i64, Shape{shapes.size()});
copy_data(shape_result, shapes);
auto result_tv1 = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
auto result_tv2 = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
auto handle = backend->compile(f);
handle->call_with_validate({result_tv1, result_tv2}, {shape_result});
ASSERT_EQ(result_tv1->get_shape(), (Shape{1, 128}));
ASSERT_EQ(result_tv2->get_shape(), (Shape{1, 128}));
auto result1 = read_vector<float>(result_tv1);
auto result2 = read_vector<float>(result_tv2);
ASSERT_TRUE(test::all_close_f(result1, result2));
ASSERT_FALSE(std::any_of(result1.begin(), result1.end(), is_not_zero_or_one));
auto result_tv1_2 = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
auto result_tv2_2 = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
handle->call_with_validate({result_tv1_2, result_tv2_2}, {shape_result});
auto result1_2 = read_vector<float>(result_tv1_2);
auto result2_2 = read_vector<float>(result_tv2_2);
ASSERT_TRUE(test::all_close_f(result1, result1_2));
ASSERT_FALSE(std::any_of(result1_2.begin(), result1_2.end(), is_not_zero_or_one));
ASSERT_TRUE(test::all_close_f(result2, result2_2));
ASSERT_FALSE(std::any_of(result2_2.begin(), result2_2.end(), is_not_zero_or_one));
}
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_generate_mask_downgrade_pass)
{
Shape scalar{};
const unsigned int seed = 777;
auto training = op::Constant::create(element::f32, Shape{}, {1});
auto result_shape = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 128});
auto gen_mask =
make_shared<op::v1::GenerateMask>(training, result_shape, element::f32, seed, 0.5, false);
auto gen_mask2 =
make_shared<op::v1::GenerateMask>(training, result_shape, element::f32, seed, 0.5, false);
auto f = make_shared<Function>(NodeVector{gen_mask, gen_mask2}, ParameterVector{});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto generate_mask_v0 = static_pointer_cast<op::v0::GenerateMask>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_EQ(generate_mask_v0->description(), "GenerateMask");
EXPECT_EQ(generate_mask_v0->get_version(), 0);
EXPECT_EQ(generate_mask_v0->get_mask_shape(), (Shape{1, 128}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment