Commit ab440246 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Add RandomUniform op (#3611)

* Add RandomUniform op

* Missing files

* Make CPU compile again (RandomUniform tests still don't pass)

* Bop users of existing RNGState class and reorder some comment junk

* Try to bop RNGState in gcpu

* Add RandomUniform to CPU manifest

* Add RandomUniform to PlaidML manifest

* Add .rst for random_uniform

* Clean up junk in the .rst

* Change UniformRNGState to always use double internally

* Change weird test failure message

* Compilation issues
parent f5c89181
...@@ -63,6 +63,7 @@ Not currently a comprehensive list. ...@@ -63,6 +63,7 @@ Not currently a comprehensive list.
* :doc:`power` * :doc:`power`
* :doc:`product` * :doc:`product`
* :doc:`quantize` * :doc:`quantize`
* :doc:`random_uniform`
* :doc:`relu` * :doc:`relu`
* :doc:`result` * :doc:`result`
* :doc:`shape_of` * :doc:`shape_of`
...@@ -136,6 +137,7 @@ Not currently a comprehensive list. ...@@ -136,6 +137,7 @@ Not currently a comprehensive list.
power.rst power.rst
product.rst product.rst
quantize.rst quantize.rst
random_uniform.rst
relu.rst relu.rst
result.rst result.rst
shape_of.rst shape_of.rst
......
.. random_uniform.rst:
#############
RandomUniform
#############
.. code-block:: cpp
RandomUniform // Operation that generates a tensor populated with random
// values of a uniform distribution.
Description
===========
.. warning:: This op is experimental and subject to change without notice.
Inputs
------
+--------------------+-------------------------+---------------------------------+-------------------------------------------+
| Name | Element Type | Shape | Notes |
+====================+=========================+=============================================================================+
| ``min_value`` | Any floating point type | Scalar | Minimum value for the random distribution |
+--------------------+-------------------------+---------------------------------+-------------------------------------------+
| ``max_value`` | Same as ``max_value`` | Scalar | Maximum value for the random distribution |
+--------------------+-------------------------+---------------------------------+-------------------------------------------+
| ``result_shape`` | ``element::i64`` | Vector of any size | Shape of the output tensor |
+--------------------+-------------------------+---------------------------------+-------------------------------------------+
| ``use_fixed_seed`` | ``element::boolean`` | Scalar | Flag indicating whether to use the fixed |
| | | | seed value ``fixed_seed`` (useful for |
| | | | testing) |
+--------------------+-------------------------+---------------------------------+-------------------------------------------+
Attributes
-----------
+---------------------+---------------+-----------------------------------------------------------------------------------------+
| Name | Type | Notes |
+=====================+===============+=========================================================================================+
| ``fixed_seed`` | ``uint64_t`` | Fixed seed value to use if ``use_fixed_seed`` flag is set to ``1``. This should be used |
| | | only for testing; if ``use_fixed_seed`` is ``1``, ``RandomUniform`` will produce the |
| | | _same_ values at each iteration. |
+---------------------+---------------+-----------------------------------------------------------------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------------------+
| Name | Element Type | Shape |
+=================+=========================+============================================+
| ``output`` | Same as ``min_value`` | ``result_shape`` |
+-----------------+-------------------------+--------------------------------------------+
Mathematical Definition
=======================
.. math::
\mathtt{output}_i = \mathtt{uniform_rand}(\mathtt{min}=\mathtt{min_value}, \mathtt{max}=\mathtt{max_value})
C++ Interface
=============
.. doxygenclass:: ngraph::op::RandomUniform
:project: ngraph
:members:
...@@ -196,6 +196,8 @@ set (SRC ...@@ -196,6 +196,8 @@ set (SRC
op/experimental/layers/reorg_yolo.cpp op/experimental/layers/reorg_yolo.cpp
op/experimental/layers/roi_pooling.hpp op/experimental/layers/roi_pooling.hpp
op/experimental/layers/roi_pooling.cpp op/experimental/layers/roi_pooling.cpp
op/experimental/random_uniform.hpp
op/experimental/random_uniform.cpp
op/floor.cpp op/floor.cpp
op/floor.hpp op/floor.hpp
op/gather.cpp op/gather.cpp
...@@ -507,7 +509,10 @@ set (SRC ...@@ -507,7 +509,10 @@ set (SRC
slice_plan.hpp slice_plan.hpp
specialize_function.cpp specialize_function.cpp
specialize_function.hpp specialize_function.hpp
state/rng_state.cpp state/bernoulli_rng_state.cpp
state/bernoulli_rng_state.hpp
state/uniform_rng_state.cpp
state/uniform_rng_state.hpp
strides.cpp strides.cpp
strides.hpp strides.hpp
type/bfloat16.cpp type/bfloat16.cpp
......
...@@ -118,6 +118,7 @@ namespace ngraph ...@@ -118,6 +118,7 @@ namespace ngraph
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp" #include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp" #include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/range.hpp" #include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp" #include "ngraph/op/experimental/tile.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::RandomUniform::type_info;
op::RandomUniform::RandomUniform(const Output<Node>& min_value,
const Output<Node>& max_value,
const Output<Node>& result_shape,
const Output<Node>& use_fixed_seed,
uint64_t fixed_seed)
: Op({min_value, max_value, result_shape, use_fixed_seed})
, m_fixed_seed(fixed_seed)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::RandomUniform::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::RandomUniform>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_fixed_seed);
}
void ngraph::op::RandomUniform::validate_and_infer_types()
{
element::Type result_element_type;
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_element_type,
input(0).get_element_type(),
input(1).get_element_type()),
"Element types for min and max values do not match.");
NODE_VALIDATION_CHECK(this,
result_element_type.is_dynamic() || result_element_type.is_real(),
"Element type of min_val and max_val inputs is not floating point.");
NODE_VALIDATION_CHECK(this,
input(0).get_partial_shape().compatible(Shape{}),
"Tensor for min_value is not a scalar.");
NODE_VALIDATION_CHECK(this,
input(1).get_partial_shape().compatible(Shape{}),
"Tensor for max_value is not a scalar.");
NODE_VALIDATION_CHECK(this,
input(2).get_element_type().compatible(element::i64),
"Element type for result_shape is not element::i64.");
NODE_VALIDATION_CHECK(this,
input(2).get_partial_shape().compatible(PartialShape::dynamic(1)),
"Tensor for result_shape not a vector.");
NODE_VALIDATION_CHECK(this,
input(3).get_element_type().compatible(element::boolean),
"Element type for use_fixed_seed is not element::boolean.");
NODE_VALIDATION_CHECK(this,
input(3).get_partial_shape().compatible(Shape{}),
"Tensor for use_fixed_seed is not a scalar.");
PartialShape result_shape;
if (auto result_shape_source_constant = as_type<op::Constant>(input_value(2).get_node()))
{
result_shape = result_shape_source_constant->get_shape_val();
}
else if (input(2).get_partial_shape().rank().is_static())
{
result_shape = PartialShape::dynamic(input(2).get_partial_shape()[0]);
}
else
{
result_shape = PartialShape::dynamic();
}
set_output_size(1);
set_input_is_relevant_to_shape(2);
set_output_type(0, result_element_type, result_shape);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Generates a tensor populated with random values of a uniform distribution.
class RandomUniform : public op::Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"RandomUniform", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized RandomUniform node.
RandomUniform() = default;
/// \brief Constructs a RandomUniform node.
/// \param min_value Output producing the minimum value (inclusive) for the random
/// uniform distribution. Must return a scalar of floating point type,
/// and the type must match that of `max_value`.
/// \param max_value Output producing the maximum value (inclusive) for the random
/// uniform distribution. Must return a scalar of floating point type,
/// and the type must match that of `min_value`.
/// \param result_shape Output producing the shape of the output tensor. Must return a
/// vector of type `element::i64`.
/// \param use_fixed_seed Output producing a boolean scalar Flag indicating whether to
/// use the value supplied in `fixed_seed` to re-seed the random
/// number generator at this iteration. Note that whenever
/// `use_fixed_seed` is `true`, the same values will be generated
/// in the output tensor. This flag is primarily used for
/// debugging. If `use_fixed_seed` is `false`, the value in
/// `fixed_seed` is ignored.
/// \param fixed_seed Fixed seed value to be supplied to the random number generator if
/// `use_fixed_seed` is `true`. If `use_fixed_seed` is `false`, this
/// value is ignored.
RandomUniform(const Output<Node>& min_value,
const Output<Node>& max_value,
const Output<Node>& result_shape,
const Output<Node>& use_fixed_seed,
uint64_t fixed_seed);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \brief Returns the fixed seed value to be supplied to the random number generator
/// if `use_fixed_seed` is `true`. If `use_fixed_seed` is `false`, this value is
/// ignored.
uint64_t get_fixed_seed() const { return m_fixed_seed; }
/// \brief Sets the fixed seed value to be supplied to the random number generator
/// if `use_fixed_seed` is `true`. If `use_fixed_seed` is `false`, this value is
/// ignored.
void set_fixed_seed(uint64_t fixed_seed) { m_fixed_seed = fixed_seed; }
// Internally, any implementation of RandomUniform will have state, since it is backed
// by a random number generator.
bool has_state() const override { return true; }
void validate_and_infer_types() override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& /* adjoints */,
const NodeVector& /* deltas */) override
{
}
uint64_t m_fixed_seed;
};
}
}
...@@ -128,6 +128,7 @@ NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op) ...@@ -128,6 +128,7 @@ NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op) NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op) NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(Recv, ngraph::op) NGRAPH_OP(Recv, ngraph::op)
NGRAPH_OP(RandomUniform, ngraph::op)
NGRAPH_OP(Range, ngraph::op) NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Relu, ngraph::op) NGRAPH_OP(Relu, ngraph::op)
NGRAPH_OP(ReluBackprop, ngraph::op) NGRAPH_OP(ReluBackprop, ngraph::op)
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include "ngraph/runtime/cpu/op/dropout.hpp" #include "ngraph/runtime/cpu/op/dropout.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/dropout.hpp" #include "ngraph/runtime/cpu/kernel/dropout.hpp"
#include "ngraph/state/rng_state.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "ngraph/op/experimental/generate_mask.hpp" #include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/generate_mask.hpp" #include "ngraph/runtime/reference/generate_mask.hpp"
#include "ngraph/state/rng_state.hpp" #include "ngraph/state/bernoulli_rng_state.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -50,7 +50,7 @@ namespace ngraph ...@@ -50,7 +50,7 @@ namespace ngraph
auto seed_attr = gm->get_use_seed() ? gm->get_seed() : 0; auto seed_attr = gm->get_use_seed() ? gm->get_seed() : 0;
auto index = external_function->add_state( auto index = external_function->add_state(
ngraph::RNGState::create_rng_state(seed_attr, gm->get_probability())); new ngraph::BernoulliRNGState(seed_attr, gm->get_probability()));
if (args[0].get_element_type() == element::f32) if (args[0].get_element_type() == element::f32)
{ {
...@@ -77,7 +77,7 @@ namespace ngraph ...@@ -77,7 +77,7 @@ namespace ngraph
reference::generate_mask( reference::generate_mask(
static_cast<float*>(ctx->buffer_data[out_buffer_index]), static_cast<float*>(ctx->buffer_data[out_buffer_index]),
element_count, element_count,
static_cast<RNGState*>(ctx->states[index]), static_cast<BernoulliRNGState*>(ctx->states[index]),
training); training);
} }
else else
...@@ -116,7 +116,7 @@ namespace ngraph ...@@ -116,7 +116,7 @@ namespace ngraph
reference::generate_mask( reference::generate_mask(
static_cast<double*>(ctx->buffer_data[out_buffer_index]), static_cast<double*>(ctx->buffer_data[out_buffer_index]),
element_count, element_count,
static_cast<RNGState*>(ctx->states[index]), static_cast<BernoulliRNGState*>(ctx->states[index]),
training); training);
} }
else else
......
...@@ -130,7 +130,7 @@ ...@@ -130,7 +130,7 @@
#include "ngraph/runtime/cpu/op/rnn.hpp" #include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp" #include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp" #include "ngraph/runtime/cpu/op/update_slice.hpp"
#include "ngraph/state/rng_state.hpp" #include "ngraph/state/bernoulli_rng_state.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -4153,9 +4153,9 @@ namespace ngraph ...@@ -4153,9 +4153,9 @@ namespace ngraph
auto gm = static_cast<const ngraph::op::GenerateMask*>(node); auto gm = static_cast<const ngraph::op::GenerateMask*>(node);
writer.block_begin(); writer.block_begin();
auto index = external_function->add_state( auto index = external_function->add_state(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability())); new ngraph::BernoulliRNGState(gm->get_seed(), gm->get_probability()));
writer << "auto state = static_cast<ngraph::RNGState*>(ctx->states[" << index writer << "auto state = static_cast<ngraph::BernoulliRNGState*>(ctx->states["
<< "]);\n"; << index << "]);\n";
writer << "bool training = static_cast<bool>(" << args[0].get_name() << "[0]);\n"; writer << "bool training = static_cast<bool>(" << args[0].get_name() << "[0]);\n";
writer << "bool use_seed = static_cast<bool>(" << args[2].get_name() << "[0]);\n"; writer << "bool use_seed = static_cast<bool>(" << args[2].get_name() << "[0]);\n";
......
...@@ -566,7 +566,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_ ...@@ -566,7 +566,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
#include "ngraph/runtime/reference/topk.hpp" #include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/reference/xor.hpp" #include "ngraph/runtime/reference/xor.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/state/rng_state.hpp" #include "ngraph/state/bernoulli_rng_state.hpp"
#include "ngraph/strides.hpp" #include "ngraph/strides.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include <random> #include <random>
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/state/rng_state.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -21,3 +21,10 @@ lrn_across_all_dims ...@@ -21,3 +21,10 @@ lrn_across_all_dims
lrn_across_nw lrn_across_nw
lrn_across_empty lrn_across_empty
lrn_6D_across_2_axes lrn_6D_across_2_axes
# RandomUniform not supported in CPU backend
random_uniform_all_static_seed_unused
random_uniform_all_static_seed_used
random_uniform_seed_use_dynamic
random_uniform_all_static_range_dynamic
random_uniform_dynamic_shapes
...@@ -158,7 +158,7 @@ ...@@ -158,7 +158,7 @@
#include "ngraph/runtime/reference/topk.hpp" #include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/reference/xor.hpp" #include "ngraph/runtime/reference/xor.hpp"
#include "ngraph/runtime/tensor.hpp" #include "ngraph/runtime/tensor.hpp"
#include "ngraph/state/rng_state.hpp" #include "ngraph/state/bernoulli_rng_state.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -199,7 +199,7 @@ private: ...@@ -199,7 +199,7 @@ private:
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map; std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes; std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states; std::unordered_map<const Node*, std::shared_ptr<ngraph::State>> m_states;
std::set<std::string> m_unsupported_op_name_list; std::set<std::string> m_unsupported_op_name_list;
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&, static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
...@@ -379,12 +379,12 @@ private: ...@@ -379,12 +379,12 @@ private:
{ {
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node); const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
auto seed = use_seed ? gm->get_seed() : 0; auto seed = use_seed ? gm->get_seed() : 0;
m_states[&node] = std::unique_ptr<ngraph::RNGState>( m_states[&node] = std::unique_ptr<ngraph::State>(
ngraph::RNGState::create_rng_state(seed, gm->get_probability())); new ngraph::BernoulliRNGState(seed, gm->get_probability()));
} }
bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]); bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]);
auto state = m_states.at(&node).get(); auto state = static_cast<ngraph::BernoulliRNGState*>(m_states.at(&node).get());
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
if (!use_seed) if (!use_seed)
{ {
......
...@@ -46,6 +46,7 @@ ...@@ -46,6 +46,7 @@
#include "ngraph/op/experimental/dyn_broadcast.hpp" #include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp" #include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/generate_mask.hpp" #include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
...@@ -146,6 +147,7 @@ ...@@ -146,6 +147,7 @@
#include "ngraph/runtime/reference/power.hpp" #include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp" #include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp" #include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/random_uniform.hpp"
#include "ngraph/runtime/reference/recv.hpp" #include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp" #include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp" #include "ngraph/runtime/reference/replace_slice.hpp"
...@@ -172,7 +174,8 @@ ...@@ -172,7 +174,8 @@
#include "ngraph/runtime/reference/topk.hpp" #include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/reference/xor.hpp" #include "ngraph/runtime/reference/xor.hpp"
#include "ngraph/runtime/tensor.hpp" #include "ngraph/runtime/tensor.hpp"
#include "ngraph/state/rng_state.hpp" #include "ngraph/state/bernoulli_rng_state.hpp"
#include "ngraph/state/uniform_rng_state.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -195,7 +198,7 @@ public: ...@@ -195,7 +198,7 @@ public:
bool enable_performance_collection = false); bool enable_performance_collection = false);
bool call(const std::vector<std::shared_ptr<Tensor>>& outputs, bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& intputs) override; const std::vector<std::shared_ptr<Tensor>>& inputs) override;
virtual void save(std::ostream& output_stream) override; virtual void save(std::ostream& output_stream) override;
...@@ -225,7 +228,7 @@ private: ...@@ -225,7 +228,7 @@ private:
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map; std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes; std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states; std::unordered_map<const Node*, std::shared_ptr<State>> m_states;
std::set<std::string> m_unsupported_op_name_list; std::set<std::string> m_unsupported_op_name_list;
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&, static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
...@@ -419,12 +422,12 @@ private: ...@@ -419,12 +422,12 @@ private:
{ {
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node); const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
auto seed = use_seed ? gm->get_seed() : 0; auto seed = use_seed ? gm->get_seed() : 0;
m_states[&node] = std::unique_ptr<ngraph::RNGState>( m_states[&node] =
ngraph::RNGState::create_rng_state(seed, gm->get_probability())); std::unique_ptr<State>(new BernoulliRNGState(seed, gm->get_probability()));
} }
bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]); bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]);
auto state = m_states.at(&node).get(); auto state = static_cast<BernoulliRNGState*>(m_states.at(&node).get());
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
if (!use_seed) if (!use_seed)
{ {
...@@ -1449,6 +1452,38 @@ private: ...@@ -1449,6 +1452,38 @@ private:
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize); memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break; break;
} }
case OP_TYPEID::RandomUniform:
{
const op::RandomUniform* ru = static_cast<const op::RandomUniform*>(&node);
T min_val = args[0]->get_data_ptr<const T>()[0];
T max_val = args[1]->get_data_ptr<const T>()[0];
// In INTERPRETER we can ignore arg 2 (output_shape) for now because we only work on
// static output shapes anyway.
bool use_fixed_seed = static_cast<bool>(args[3]->get_data_ptr<const char>()[0]);
if (m_states.count(&node) == 0)
{
m_states[&node] = std::unique_ptr<UniformRNGState>(new UniformRNGState());
}
auto state = static_cast<UniformRNGState*>(m_states.at(&node).get());
size_t element_count = shape_size(node.get_output_shape(0));
if (!use_fixed_seed)
{
reference::random_uniform<T>(
out[0]->get_data_ptr<T>(), min_val, max_val, element_count, state);
}
else
{
reference::random_uniform_with_fixed_seed<T>(out[0]->get_data_ptr<T>(),
min_val,
max_val,
element_count,
ru->get_fixed_seed());
}
break;
}
case OP_TYPEID::Range: case OP_TYPEID::Range:
{ {
throw unsupported_op("Unsupported op '" + node.description() + "'"); throw unsupported_op("Unsupported op '" + node.description() + "'");
......
...@@ -272,3 +272,10 @@ lrn_across_all_dims ...@@ -272,3 +272,10 @@ lrn_across_all_dims
lrn_across_nw lrn_across_nw
lrn_across_empty lrn_across_empty
lrn_6D_across_2_axes lrn_6D_across_2_axes
# RandomUniform not supported in PlaidML backend
random_uniform_all_static_seed_unused
random_uniform_all_static_seed_used
random_uniform_seed_use_dynamic
random_uniform_all_static_range_dynamic
random_uniform_dynamic_shapes
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <random> #include <random>
#include "ngraph/state/rng_state.hpp" #include "ngraph/state/bernoulli_rng_state.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -27,7 +27,10 @@ namespace ngraph ...@@ -27,7 +27,10 @@ namespace ngraph
namespace reference namespace reference
{ {
template <typename T> template <typename T>
void generate_mask(T* out, size_t count, ngraph::RNGState* rng_state, bool training) void generate_mask(T* out,
size_t count,
ngraph::BernoulliRNGState* rng_state,
bool training)
{ {
auto& gen = rng_state->get_generator(); auto& gen = rng_state->get_generator();
auto& bd = rng_state->get_distribution(); auto& bd = rng_state->get_distribution();
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <random>
#include "ngraph/state/uniform_rng_state.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void random_uniform(
T* out, T min_val, T max_val, size_t count, ngraph::UniformRNGState* rng_state)
{
auto& gen = rng_state->get_generator();
auto& bd = rng_state->get_distribution();
for (size_t i = 0; i < count; i++)
{
out[i] = static_cast<T>(bd(gen)) * (max_val - min_val) + min_val;
}
}
template <typename T>
void random_uniform_with_fixed_seed(
T* out, T min_val, T max_val, size_t count, size_t fixed_seed)
{
ngraph::UniformRNGState rng_state(fixed_seed);
random_uniform(out, min_val, max_val, count, &rng_state);
}
}
}
}
...@@ -62,6 +62,7 @@ ...@@ -62,6 +62,7 @@
#include "ngraph/op/experimental/quantized_conv_bias.hpp" #include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp" #include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp" #include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/range.hpp" #include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp" #include "ngraph/op/experimental/tile.hpp"
...@@ -1708,6 +1709,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1708,6 +1709,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Recv>(args[0], src_id); node = make_shared<op::Recv>(args[0], src_id);
break; break;
} }
case OP_TYPEID::RandomUniform:
{
auto fixed_seed = node_js.at("fixed_seed").get<uint64_t>();
node = make_shared<op::RandomUniform>(args[0], args[1], args[2], args[3], fixed_seed);
break;
}
case OP_TYPEID::Range: case OP_TYPEID::Range:
{ {
node = make_shared<op::Range>(args[0], args[1], args[2]); node = make_shared<op::Range>(args[0], args[1], args[2]);
...@@ -2795,6 +2802,12 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -2795,6 +2802,12 @@ json JSONSerializer::serialize_node(const Node& n)
node["source_id"] = tmp->get_src_id(); node["source_id"] = tmp->get_src_id();
break; break;
} }
case OP_TYPEID::RandomUniform:
{
auto tmp = dynamic_cast<const op::RandomUniform*>(&n);
node["fixed_seed"] = tmp->get_fixed_seed();
break;
}
case OP_TYPEID::Range: { break; case OP_TYPEID::Range: { break;
} }
case OP_TYPEID::Relu: { break; case OP_TYPEID::Relu: { break;
......
...@@ -16,16 +16,16 @@ ...@@ -16,16 +16,16 @@
#include <random> #include <random>
#include "bernoulli_rng_state.hpp"
#include "except.hpp" #include "except.hpp"
#include "rng_state.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
void ngraph::RNGState::activate() void ngraph::BernoulliRNGState::activate()
{ {
} }
void ngraph::RNGState::deactivate() void ngraph::BernoulliRNGState::deactivate()
{ {
} }
...@@ -24,17 +24,10 @@ ...@@ -24,17 +24,10 @@
namespace ngraph namespace ngraph
{ {
// can be based on TensorSate to cache values instead of just caching seed class BernoulliRNGState : public State
class RNGState : public State
{ {
public: public:
static RNGState* create_rng_state(unsigned int seed, double probability) BernoulliRNGState(unsigned int seed, double probability)
{
auto rng = new RNGState(seed, probability);
return rng;
}
RNGState(unsigned int seed, double probability)
: State() : State()
, m_generator(seed) , m_generator(seed)
, m_distribution(probability) , m_distribution(probability)
...@@ -42,7 +35,7 @@ namespace ngraph ...@@ -42,7 +35,7 @@ namespace ngraph
} }
virtual void activate() override; virtual void activate() override;
virtual void deactivate() override; virtual void deactivate() override;
virtual ~RNGState() override {} virtual ~BernoulliRNGState() override {}
std::mt19937& get_generator() { return m_generator; } std::mt19937& get_generator() { return m_generator; }
std::bernoulli_distribution& get_distribution() { return m_distribution; } std::bernoulli_distribution& get_distribution() { return m_distribution; }
protected: protected:
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/state/uniform_rng_state.hpp"
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <random>
#include "state.hpp"
namespace ngraph
{
class UniformRNGState : public State
{
public:
UniformRNGState(std::mt19937::result_type seed)
: State()
, m_generator(std::mt19937::result_type(seed))
, m_distribution()
{
}
UniformRNGState()
: State()
, m_generator(std::random_device()())
, m_distribution()
{
}
virtual void activate() override {}
virtual void deactivate() override {}
virtual ~UniformRNGState() override {}
std::mt19937& get_generator() { return m_generator; }
std::uniform_real_distribution<double>& get_distribution() { return m_distribution; }
private:
std::mt19937 m_generator;
std::uniform_real_distribution<double> m_distribution;
};
}
...@@ -133,6 +133,7 @@ set(SRC ...@@ -133,6 +133,7 @@ set(SRC
type_prop/quantize.cpp type_prop/quantize.cpp
type_prop/quantized_convolution.cpp type_prop/quantized_convolution.cpp
type_prop/quantized_dot.cpp type_prop/quantized_dot.cpp
type_prop/random_uniform.cpp
type_prop/range.cpp type_prop/range.cpp
type_prop/replace_slice.cpp type_prop/replace_slice.cpp
type_prop/reshape.cpp type_prop/reshape.cpp
...@@ -289,6 +290,7 @@ set(MULTI_TEST_SRC ...@@ -289,6 +290,7 @@ set(MULTI_TEST_SRC
backend/quantize_dequantize.in.cpp backend/quantize_dequantize.in.cpp
backend/quantized_convolution.in.cpp backend/quantized_convolution.in.cpp
backend/quantized_dot.in.cpp backend/quantized_dot.in.cpp
backend/random_uniform.in.cpp
backend/range.in.cpp backend/range.in.cpp
backend/relu.in.cpp backend/relu.in.cpp
backend/replace_slice.in.cpp backend/replace_slice.in.cpp
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment