Commit ab440246 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Add RandomUniform op (#3611)

* Add RandomUniform op

* Missing files

* Make CPU compile again (RandomUniform tests still don't pass)

* Bop users of existing RNGState class and reorder some comment junk

* Try to bop RNGState in gcpu

* Add RandomUniform to CPU manifest

* Add RandomUniform to PlaidML manifest

* Add .rst for random_uniform

* Clean up junk in the .rst

* Change UniformRNGState to always use double internally

* Change weird test failure message

* Compilation issues
parent f5c89181
......@@ -63,6 +63,7 @@ Not currently a comprehensive list.
* :doc:`power`
* :doc:`product`
* :doc:`quantize`
* :doc:`random_uniform`
* :doc:`relu`
* :doc:`result`
* :doc:`shape_of`
......@@ -136,6 +137,7 @@ Not currently a comprehensive list.
power.rst
product.rst
quantize.rst
random_uniform.rst
relu.rst
result.rst
shape_of.rst
......
.. random_uniform.rst:
#############
RandomUniform
#############
.. code-block:: cpp
RandomUniform // Operation that generates a tensor populated with random
// values of a uniform distribution.
Description
===========
.. warning:: This op is experimental and subject to change without notice.
Inputs
------
+--------------------+-------------------------+---------------------------------+-------------------------------------------+
| Name | Element Type | Shape | Notes |
+====================+=========================+=============================================================================+
| ``min_value`` | Any floating point type | Scalar | Minimum value for the random distribution |
+--------------------+-------------------------+---------------------------------+-------------------------------------------+
| ``max_value`` | Same as ``max_value`` | Scalar | Maximum value for the random distribution |
+--------------------+-------------------------+---------------------------------+-------------------------------------------+
| ``result_shape`` | ``element::i64`` | Vector of any size | Shape of the output tensor |
+--------------------+-------------------------+---------------------------------+-------------------------------------------+
| ``use_fixed_seed`` | ``element::boolean`` | Scalar | Flag indicating whether to use the fixed |
| | | | seed value ``fixed_seed`` (useful for |
| | | | testing) |
+--------------------+-------------------------+---------------------------------+-------------------------------------------+
Attributes
-----------
+---------------------+---------------+-----------------------------------------------------------------------------------------+
| Name | Type | Notes |
+=====================+===============+=========================================================================================+
| ``fixed_seed`` | ``uint64_t`` | Fixed seed value to use if ``use_fixed_seed`` flag is set to ``1``. This should be used |
| | | only for testing; if ``use_fixed_seed`` is ``1``, ``RandomUniform`` will produce the |
| | | _same_ values at each iteration. |
+---------------------+---------------+-----------------------------------------------------------------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------------------+
| Name | Element Type | Shape |
+=================+=========================+============================================+
| ``output`` | Same as ``min_value`` | ``result_shape`` |
+-----------------+-------------------------+--------------------------------------------+
Mathematical Definition
=======================
.. math::
\mathtt{output}_i = \mathtt{uniform_rand}(\mathtt{min}=\mathtt{min_value}, \mathtt{max}=\mathtt{max_value})
C++ Interface
=============
.. doxygenclass:: ngraph::op::RandomUniform
:project: ngraph
:members:
......@@ -196,6 +196,8 @@ set (SRC
op/experimental/layers/reorg_yolo.cpp
op/experimental/layers/roi_pooling.hpp
op/experimental/layers/roi_pooling.cpp
op/experimental/random_uniform.hpp
op/experimental/random_uniform.cpp
op/floor.cpp
op/floor.hpp
op/gather.cpp
......@@ -507,7 +509,10 @@ set (SRC
slice_plan.hpp
specialize_function.cpp
specialize_function.hpp
state/rng_state.cpp
state/bernoulli_rng_state.cpp
state/bernoulli_rng_state.hpp
state/uniform_rng_state.cpp
state/uniform_rng_state.hpp
strides.cpp
strides.hpp
type/bfloat16.cpp
......
......@@ -118,6 +118,7 @@ namespace ngraph
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::RandomUniform::type_info;
op::RandomUniform::RandomUniform(const Output<Node>& min_value,
const Output<Node>& max_value,
const Output<Node>& result_shape,
const Output<Node>& use_fixed_seed,
uint64_t fixed_seed)
: Op({min_value, max_value, result_shape, use_fixed_seed})
, m_fixed_seed(fixed_seed)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::RandomUniform::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::RandomUniform>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_fixed_seed);
}
void ngraph::op::RandomUniform::validate_and_infer_types()
{
element::Type result_element_type;
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_element_type,
input(0).get_element_type(),
input(1).get_element_type()),
"Element types for min and max values do not match.");
NODE_VALIDATION_CHECK(this,
result_element_type.is_dynamic() || result_element_type.is_real(),
"Element type of min_val and max_val inputs is not floating point.");
NODE_VALIDATION_CHECK(this,
input(0).get_partial_shape().compatible(Shape{}),
"Tensor for min_value is not a scalar.");
NODE_VALIDATION_CHECK(this,
input(1).get_partial_shape().compatible(Shape{}),
"Tensor for max_value is not a scalar.");
NODE_VALIDATION_CHECK(this,
input(2).get_element_type().compatible(element::i64),
"Element type for result_shape is not element::i64.");
NODE_VALIDATION_CHECK(this,
input(2).get_partial_shape().compatible(PartialShape::dynamic(1)),
"Tensor for result_shape not a vector.");
NODE_VALIDATION_CHECK(this,
input(3).get_element_type().compatible(element::boolean),
"Element type for use_fixed_seed is not element::boolean.");
NODE_VALIDATION_CHECK(this,
input(3).get_partial_shape().compatible(Shape{}),
"Tensor for use_fixed_seed is not a scalar.");
PartialShape result_shape;
if (auto result_shape_source_constant = as_type<op::Constant>(input_value(2).get_node()))
{
result_shape = result_shape_source_constant->get_shape_val();
}
else if (input(2).get_partial_shape().rank().is_static())
{
result_shape = PartialShape::dynamic(input(2).get_partial_shape()[0]);
}
else
{
result_shape = PartialShape::dynamic();
}
set_output_size(1);
set_input_is_relevant_to_shape(2);
set_output_type(0, result_element_type, result_shape);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Generates a tensor populated with random values of a uniform distribution.
class RandomUniform : public op::Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"RandomUniform", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized RandomUniform node.
RandomUniform() = default;
/// \brief Constructs a RandomUniform node.
/// \param min_value Output producing the minimum value (inclusive) for the random
/// uniform distribution. Must return a scalar of floating point type,
/// and the type must match that of `max_value`.
/// \param max_value Output producing the maximum value (inclusive) for the random
/// uniform distribution. Must return a scalar of floating point type,
/// and the type must match that of `min_value`.
/// \param result_shape Output producing the shape of the output tensor. Must return a
/// vector of type `element::i64`.
/// \param use_fixed_seed Output producing a boolean scalar Flag indicating whether to
/// use the value supplied in `fixed_seed` to re-seed the random
/// number generator at this iteration. Note that whenever
/// `use_fixed_seed` is `true`, the same values will be generated
/// in the output tensor. This flag is primarily used for
/// debugging. If `use_fixed_seed` is `false`, the value in
/// `fixed_seed` is ignored.
/// \param fixed_seed Fixed seed value to be supplied to the random number generator if
/// `use_fixed_seed` is `true`. If `use_fixed_seed` is `false`, this
/// value is ignored.
RandomUniform(const Output<Node>& min_value,
const Output<Node>& max_value,
const Output<Node>& result_shape,
const Output<Node>& use_fixed_seed,
uint64_t fixed_seed);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \brief Returns the fixed seed value to be supplied to the random number generator
/// if `use_fixed_seed` is `true`. If `use_fixed_seed` is `false`, this value is
/// ignored.
uint64_t get_fixed_seed() const { return m_fixed_seed; }
/// \brief Sets the fixed seed value to be supplied to the random number generator
/// if `use_fixed_seed` is `true`. If `use_fixed_seed` is `false`, this value is
/// ignored.
void set_fixed_seed(uint64_t fixed_seed) { m_fixed_seed = fixed_seed; }
// Internally, any implementation of RandomUniform will have state, since it is backed
// by a random number generator.
bool has_state() const override { return true; }
void validate_and_infer_types() override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& /* adjoints */,
const NodeVector& /* deltas */) override
{
}
uint64_t m_fixed_seed;
};
}
}
......@@ -128,6 +128,7 @@ NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(Recv, ngraph::op)
NGRAPH_OP(RandomUniform, ngraph::op)
NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Relu, ngraph::op)
NGRAPH_OP(ReluBackprop, ngraph::op)
......
......@@ -17,7 +17,6 @@
#include "ngraph/runtime/cpu/op/dropout.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/dropout.hpp"
#include "ngraph/state/rng_state.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -17,7 +17,7 @@
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/generate_mask.hpp"
#include "ngraph/state/rng_state.hpp"
#include "ngraph/state/bernoulli_rng_state.hpp"
using namespace std;
using namespace ngraph;
......@@ -50,7 +50,7 @@ namespace ngraph
auto seed_attr = gm->get_use_seed() ? gm->get_seed() : 0;
auto index = external_function->add_state(
ngraph::RNGState::create_rng_state(seed_attr, gm->get_probability()));
new ngraph::BernoulliRNGState(seed_attr, gm->get_probability()));
if (args[0].get_element_type() == element::f32)
{
......@@ -77,7 +77,7 @@ namespace ngraph
reference::generate_mask(
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
element_count,
static_cast<RNGState*>(ctx->states[index]),
static_cast<BernoulliRNGState*>(ctx->states[index]),
training);
}
else
......@@ -116,7 +116,7 @@ namespace ngraph
reference::generate_mask(
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
element_count,
static_cast<RNGState*>(ctx->states[index]),
static_cast<BernoulliRNGState*>(ctx->states[index]),
training);
}
else
......
......@@ -130,7 +130,7 @@
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
#include "ngraph/state/rng_state.hpp"
#include "ngraph/state/bernoulli_rng_state.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
......@@ -4153,9 +4153,9 @@ namespace ngraph
auto gm = static_cast<const ngraph::op::GenerateMask*>(node);
writer.block_begin();
auto index = external_function->add_state(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
writer << "auto state = static_cast<ngraph::RNGState*>(ctx->states[" << index
<< "]);\n";
new ngraph::BernoulliRNGState(gm->get_seed(), gm->get_probability()));
writer << "auto state = static_cast<ngraph::BernoulliRNGState*>(ctx->states["
<< index << "]);\n";
writer << "bool training = static_cast<bool>(" << args[0].get_name() << "[0]);\n";
writer << "bool use_seed = static_cast<bool>(" << args[2].get_name() << "[0]);\n";
......
......@@ -566,7 +566,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/reference/xor.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/state/rng_state.hpp"
#include "ngraph/state/bernoulli_rng_state.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/util.hpp"
......
......@@ -19,7 +19,6 @@
#include <random>
#include "ngraph/shape.hpp"
#include "ngraph/state/rng_state.hpp"
namespace ngraph
{
......
......@@ -21,3 +21,10 @@ lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
# RandomUniform not supported in CPU backend
random_uniform_all_static_seed_unused
random_uniform_all_static_seed_used
random_uniform_seed_use_dynamic
random_uniform_all_static_range_dynamic
random_uniform_dynamic_shapes
......@@ -158,7 +158,7 @@
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/reference/xor.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/state/rng_state.hpp"
#include "ngraph/state/bernoulli_rng_state.hpp"
namespace ngraph
{
......@@ -199,7 +199,7 @@ private:
std::shared_ptr<Function> m_function;
std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states;
std::unordered_map<const Node*, std::shared_ptr<ngraph::State>> m_states;
std::set<std::string> m_unsupported_op_name_list;
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
......@@ -379,12 +379,12 @@ private:
{
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
auto seed = use_seed ? gm->get_seed() : 0;
m_states[&node] = std::unique_ptr<ngraph::RNGState>(
ngraph::RNGState::create_rng_state(seed, gm->get_probability()));
m_states[&node] = std::unique_ptr<ngraph::State>(
new ngraph::BernoulliRNGState(seed, gm->get_probability()));
}
bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]);
auto state = m_states.at(&node).get();
auto state = static_cast<ngraph::BernoulliRNGState*>(m_states.at(&node).get());
size_t element_count = shape_size(node.get_output_shape(0));
if (!use_seed)
{
......
......@@ -46,6 +46,7 @@
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -146,6 +147,7 @@
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/random_uniform.hpp"
#include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
......@@ -172,7 +174,8 @@
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/reference/xor.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/state/rng_state.hpp"
#include "ngraph/state/bernoulli_rng_state.hpp"
#include "ngraph/state/uniform_rng_state.hpp"
namespace ngraph
{
......@@ -195,7 +198,7 @@ public:
bool enable_performance_collection = false);
bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& intputs) override;
const std::vector<std::shared_ptr<Tensor>>& inputs) override;
virtual void save(std::ostream& output_stream) override;
......@@ -225,7 +228,7 @@ private:
std::shared_ptr<Function> m_function;
std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states;
std::unordered_map<const Node*, std::shared_ptr<State>> m_states;
std::set<std::string> m_unsupported_op_name_list;
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
......@@ -419,12 +422,12 @@ private:
{
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
auto seed = use_seed ? gm->get_seed() : 0;
m_states[&node] = std::unique_ptr<ngraph::RNGState>(
ngraph::RNGState::create_rng_state(seed, gm->get_probability()));
m_states[&node] =
std::unique_ptr<State>(new BernoulliRNGState(seed, gm->get_probability()));
}
bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]);
auto state = m_states.at(&node).get();
auto state = static_cast<BernoulliRNGState*>(m_states.at(&node).get());
size_t element_count = shape_size(node.get_output_shape(0));
if (!use_seed)
{
......@@ -1449,6 +1452,38 @@ private:
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::RandomUniform:
{
const op::RandomUniform* ru = static_cast<const op::RandomUniform*>(&node);
T min_val = args[0]->get_data_ptr<const T>()[0];
T max_val = args[1]->get_data_ptr<const T>()[0];
// In INTERPRETER we can ignore arg 2 (output_shape) for now because we only work on
// static output shapes anyway.
bool use_fixed_seed = static_cast<bool>(args[3]->get_data_ptr<const char>()[0]);
if (m_states.count(&node) == 0)
{
m_states[&node] = std::unique_ptr<UniformRNGState>(new UniformRNGState());
}
auto state = static_cast<UniformRNGState*>(m_states.at(&node).get());
size_t element_count = shape_size(node.get_output_shape(0));
if (!use_fixed_seed)
{
reference::random_uniform<T>(
out[0]->get_data_ptr<T>(), min_val, max_val, element_count, state);
}
else
{
reference::random_uniform_with_fixed_seed<T>(out[0]->get_data_ptr<T>(),
min_val,
max_val,
element_count,
ru->get_fixed_seed());
}
break;
}
case OP_TYPEID::Range:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
......
......@@ -272,3 +272,10 @@ lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
# RandomUniform not supported in PlaidML backend
random_uniform_all_static_seed_unused
random_uniform_all_static_seed_used
random_uniform_seed_use_dynamic
random_uniform_all_static_range_dynamic
random_uniform_dynamic_shapes
......@@ -18,7 +18,7 @@
#include <random>
#include "ngraph/state/rng_state.hpp"
#include "ngraph/state/bernoulli_rng_state.hpp"
namespace ngraph
{
......@@ -27,7 +27,10 @@ namespace ngraph
namespace reference
{
template <typename T>
void generate_mask(T* out, size_t count, ngraph::RNGState* rng_state, bool training)
void generate_mask(T* out,
size_t count,
ngraph::BernoulliRNGState* rng_state,
bool training)
{
auto& gen = rng_state->get_generator();
auto& bd = rng_state->get_distribution();
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <random>
#include "ngraph/state/uniform_rng_state.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void random_uniform(
T* out, T min_val, T max_val, size_t count, ngraph::UniformRNGState* rng_state)
{
auto& gen = rng_state->get_generator();
auto& bd = rng_state->get_distribution();
for (size_t i = 0; i < count; i++)
{
out[i] = static_cast<T>(bd(gen)) * (max_val - min_val) + min_val;
}
}
template <typename T>
void random_uniform_with_fixed_seed(
T* out, T min_val, T max_val, size_t count, size_t fixed_seed)
{
ngraph::UniformRNGState rng_state(fixed_seed);
random_uniform(out, min_val, max_val, count, &rng_state);
}
}
}
}
......@@ -62,6 +62,7 @@
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
......@@ -1708,6 +1709,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Recv>(args[0], src_id);
break;
}
case OP_TYPEID::RandomUniform:
{
auto fixed_seed = node_js.at("fixed_seed").get<uint64_t>();
node = make_shared<op::RandomUniform>(args[0], args[1], args[2], args[3], fixed_seed);
break;
}
case OP_TYPEID::Range:
{
node = make_shared<op::Range>(args[0], args[1], args[2]);
......@@ -2795,6 +2802,12 @@ json JSONSerializer::serialize_node(const Node& n)
node["source_id"] = tmp->get_src_id();
break;
}
case OP_TYPEID::RandomUniform:
{
auto tmp = dynamic_cast<const op::RandomUniform*>(&n);
node["fixed_seed"] = tmp->get_fixed_seed();
break;
}
case OP_TYPEID::Range: { break;
}
case OP_TYPEID::Relu: { break;
......
......@@ -16,16 +16,16 @@
#include <random>
#include "bernoulli_rng_state.hpp"
#include "except.hpp"
#include "rng_state.hpp"
using namespace std;
using namespace ngraph;
void ngraph::RNGState::activate()
void ngraph::BernoulliRNGState::activate()
{
}
void ngraph::RNGState::deactivate()
void ngraph::BernoulliRNGState::deactivate()
{
}
......@@ -24,17 +24,10 @@
namespace ngraph
{
// can be based on TensorSate to cache values instead of just caching seed
class RNGState : public State
class BernoulliRNGState : public State
{
public:
static RNGState* create_rng_state(unsigned int seed, double probability)
{
auto rng = new RNGState(seed, probability);
return rng;
}
RNGState(unsigned int seed, double probability)
BernoulliRNGState(unsigned int seed, double probability)
: State()
, m_generator(seed)
, m_distribution(probability)
......@@ -42,7 +35,7 @@ namespace ngraph
}
virtual void activate() override;
virtual void deactivate() override;
virtual ~RNGState() override {}
virtual ~BernoulliRNGState() override {}
std::mt19937& get_generator() { return m_generator; }
std::bernoulli_distribution& get_distribution() { return m_distribution; }
protected:
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/state/uniform_rng_state.hpp"
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <random>
#include "state.hpp"
namespace ngraph
{
class UniformRNGState : public State
{
public:
UniformRNGState(std::mt19937::result_type seed)
: State()
, m_generator(std::mt19937::result_type(seed))
, m_distribution()
{
}
UniformRNGState()
: State()
, m_generator(std::random_device()())
, m_distribution()
{
}
virtual void activate() override {}
virtual void deactivate() override {}
virtual ~UniformRNGState() override {}
std::mt19937& get_generator() { return m_generator; }
std::uniform_real_distribution<double>& get_distribution() { return m_distribution; }
private:
std::mt19937 m_generator;
std::uniform_real_distribution<double> m_distribution;
};
}
......@@ -133,6 +133,7 @@ set(SRC
type_prop/quantize.cpp
type_prop/quantized_convolution.cpp
type_prop/quantized_dot.cpp
type_prop/random_uniform.cpp
type_prop/range.cpp
type_prop/replace_slice.cpp
type_prop/reshape.cpp
......@@ -289,6 +290,7 @@ set(MULTI_TEST_SRC
backend/quantize_dequantize.in.cpp
backend/quantized_convolution.in.cpp
backend/quantized_dot.in.cpp
backend/random_uniform.in.cpp
backend/range.in.cpp
backend/relu.in.cpp
backend/replace_slice.in.cpp
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment