Commit af2b137b authored by gaurides's avatar gaurides Committed by Scott Cyphers

Support for RandomUniformOp in CPU (#3621)

* Add RandomUniform op

* Missing files

* Make CPU compile again (RandomUniform tests still don't pass)

* Bop users of existing RNGState class and reorder some comment junk

* Try to bop RNGState in gcpu

* Add RandomUniform to CPU manifest

* Add RandomUniform to PlaidML manifest

* Add .rst for random_uniform

* Clean up junk in the .rst

* Change UniformRNGState to always use double internally

* Change weird test failure message

* RandomUniform implementation for CPU backend

* Add missing file

* Add codegen support

* Code cleanup

* Address review comments

* Fix CI failure

* Fixed CI error

* Fixed CI error
parent ca413b5b
......@@ -66,6 +66,7 @@ set(SRC
builder/max_pool.cpp
builder/min.cpp
builder/one_hot.cpp
builder/random_uniform.cpp
builder/relu.cpp
builder/pad.cpp
builder/product.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/random_uniform.hpp"
#include "ngraph/state/uniform_rng_state.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
// Builds a CPU kernel functor that fills the output tensor with values drawn
// uniformly from [min_val, max_val) for element type T.
//
// Runtime inputs (resolved through buffer indices at execution time):
//   args[0] - min_val, scalar of type T
//   args[1] - max_val, scalar of type T
//   args[2] - output_shape (currently unused; see TODO in the functor body)
//   args[3] - use_fixed_seed flag, bool stored in a char buffer
//   out[0]  - destination tensor of element_count elements of type T
//
// A UniformRNGState is registered with the external function and is consulted
// only when use_fixed_seed is false; otherwise the op's fixed seed is used so
// results are reproducible.
template <typename T>
CPUKernelFunctor prepare_functor(const Node* node,
                                 const vector<TensorViewWrapper>& args,
                                 const vector<TensorViewWrapper>& out,
                                 CPU_ExternalFunction* external_function)
{
    auto ru = static_cast<const ngraph::op::RandomUniform*>(node);
    auto arg0_buffer_index =
        external_function->get_buffer_index(args[0].get_name()); // min_val
    auto arg1_buffer_index =
        external_function->get_buffer_index(args[1].get_name()); // max_val
    auto arg2_buffer_index =
        external_function->get_buffer_index(args[2].get_name()); // output_shape
    auto arg3_buffer_index =
        external_function->get_buffer_index(args[3].get_name()); // use_fixed_seed
    auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());

    size_t element_count = out[0].get_size();
    auto index = external_function->add_state(new ngraph::UniformRNGState());
    auto fixed_seed = ru->get_fixed_seed();

    // Capture everything by value: the functor outlives this builder call, so
    // a default-by-reference capture ("[&, ...]") would dangle. Every name the
    // lambda uses is listed explicitly below.
    return [index,
            element_count,
            arg0_buffer_index,
            arg1_buffer_index,
            arg2_buffer_index,
            arg3_buffer_index,
            out_buffer_index,
            fixed_seed](CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
        // TODO: get shape when required (arg2_buffer_index is reserved for it)
        T min_val = static_cast<T*>(ctx->buffer_data[arg0_buffer_index])[0];
        T max_val = static_cast<T*>(ctx->buffer_data[arg1_buffer_index])[0];
        bool use_fixed_seed = static_cast<bool>(
            static_cast<char*>(ctx->buffer_data[arg3_buffer_index])[0]);

        if (!use_fixed_seed)
        {
            // Stateful path: advance the RNG state registered for this op.
            reference::random_uniform<T>(
                static_cast<T*>(ctx->buffer_data[out_buffer_index]),
                min_val,
                max_val,
                element_count,
                static_cast<UniformRNGState*>(ctx->states[index]));
        }
        else
        {
            // Reproducible path: seed from the op attribute every invocation.
            reference::random_uniform_with_fixed_seed<T>(
                static_cast<T*>(ctx->buffer_data[out_buffer_index]),
                min_val,
                max_val,
                element_count,
                fixed_seed);
        }
    };
}
// Builder entry point for ngraph::op::RandomUniform: validates inputs,
// dispatches on the output element type, and queues the resulting functor.
template <>
void Builder::BUILDER_DECL(ngraph::op::RandomUniform)
{
    auto& functors = external_function->get_functors();
    CPUKernelFunctor functor;

    // The output_shape input (args[2]) must be i64.
    if (args[2].get_element_type() != element::i64)
    {
        throw ngraph_error("Unsupported index 2 element type");
    }

    auto element_type = args[0].get_element_type();

// Promote missing-enum-case diagnostics to errors so a newly added element
// type cannot be silently left unhandled (GCC 4.8 mishandles these pragmas).
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
    switch (element_type)
    {
    case element::Type_t::undefined:
        // Fixed message: previously said "fold_constant_convert", a copy-paste
        // from the constant-folding pass.
        NGRAPH_CHECK(false, "Encountered 'undefined' element type in RandomUniform builder");
        break;
    case element::Type_t::dynamic:
        NGRAPH_CHECK(false, "Encountered 'dynamic' element type in RandomUniform builder");
        break;
    case element::Type_t::boolean:
        functor = prepare_functor<char>(node, args, out, external_function);
        break;
    case element::Type_t::bf16:
        functor = prepare_functor<bfloat16>(node, args, out, external_function);
        break;
    case element::Type_t::f16:
        functor = prepare_functor<float16>(node, args, out, external_function);
        break;
    case element::Type_t::f32:
        functor = prepare_functor<float>(node, args, out, external_function);
        break;
    case element::Type_t::f64:
        functor = prepare_functor<double>(node, args, out, external_function);
        break;
    case element::Type_t::i8:
        functor = prepare_functor<int8_t>(node, args, out, external_function);
        break;
    case element::Type_t::i16:
        functor = prepare_functor<int16_t>(node, args, out, external_function);
        break;
    case element::Type_t::i32:
        functor = prepare_functor<int32_t>(node, args, out, external_function);
        break;
    case element::Type_t::i64:
        functor = prepare_functor<int64_t>(node, args, out, external_function);
        break;
    case element::Type_t::u8:
        functor = prepare_functor<uint8_t>(node, args, out, external_function);
        break;
    case element::Type_t::u16:
        functor = prepare_functor<uint16_t>(node, args, out, external_function);
        break;
    case element::Type_t::u32:
        functor = prepare_functor<uint32_t>(node, args, out, external_function);
        break;
    case element::Type_t::u64:
        functor = prepare_functor<uint64_t>(node, args, out, external_function);
        break;
        // Every enum value is handled above (enforced by -Wswitch-enum), so
        // this statement is intentionally dead.
        NGRAPH_UNREACHABLE("Unexpected switch case");
    }
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif

    functors.emplace_back(functor);
}
// Registers the RandomUniform builder with the CPU backend's builder map.
void register_builders_random_uniform_cpp() { REGISTER_OP_BUILDER(RandomUniform); }
}
}
}
......@@ -57,6 +57,7 @@ namespace ngraph
register_builders_quantized_conv_cpp();
register_builders_quantized_dot_cpp();
register_builders_quantized_matmul_cpp();
register_builders_random_uniform_cpp();
register_builders_reduce_function_cpp();
register_builders_relu_cpp();
register_builders_replace_slice_cpp();
......
......@@ -56,6 +56,7 @@ namespace ngraph
void register_builders_quantized_conv_cpp();
void register_builders_quantized_dot_cpp();
void register_builders_quantized_matmul_cpp();
void register_builders_random_uniform_cpp();
void register_builders_reduce_function_cpp();
void register_builders_relu_cpp();
void register_builders_replace_slice_cpp();
......
......@@ -58,6 +58,7 @@
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
......@@ -131,6 +132,7 @@
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
#include "ngraph/state/bernoulli_rng_state.hpp"
#include "ngraph/state/uniform_rng_state.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
......@@ -4179,6 +4181,45 @@ namespace ngraph
writer.block_end();
}
// Codegen emitter for ngraph::op::RandomUniform: writes C++ that mirrors the
// interpreted builder — query the use_fixed_seed flag at runtime and call the
// matching reference kernel.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::RandomUniform)
{
    auto ru = static_cast<const ngraph::op::RandomUniform*>(node);

    // The output_shape input (args[2]) must be i64, as in the builder.
    if (args[2].get_element_type() != element::i64)
    {
        throw ngraph_error("Unsupported index 2 element type");
    }

    writer.block_begin();

    auto index = external_function->add_state(new UniformRNGState());
    auto fixed_seed = ru->get_fixed_seed();

    // Fix: the stored state is an ngraph::UniformRNGState (there is no
    // RandomUniformRNGState type), so cast to the correct class or the
    // generated code will not compile.
    writer << "auto state = static_cast<ngraph::UniformRNGState*>(ctx->states["
           << index << "]);\n";
    writer << "bool use_fixed_seed = static_cast<bool>(" << args[3].get_name()
           << "[0]);\n";
    // min/max arrive as pointers to scalars; dereference with [0] because
    // reference::random_uniform takes scalar bounds (the interpreted builder
    // dereferences the same buffers before the identical call).
    writer << "if (use_fixed_seed == false) \n";
    writer << "{\n";
    writer << "    reference::random_uniform<" << args[0].get_type() << ">(\n";
    writer << "                   " << out[0].get_name() << ",\n";
    writer << "                   " << args[0].get_name() << "[0],\n";
    writer << "                   " << args[1].get_name() << "[0],\n";
    writer << "                   " << out[0].get_size() << ",\n";
    writer << "                   state);\n";
    writer << "}\n";
    writer << "else {\n";
    writer << "    reference::random_uniform_with_fixed_seed<" << args[0].get_type()
           << ">(\n";
    writer << "                   " << out[0].get_name() << ",\n";
    writer << "                   " << args[0].get_name() << "[0],\n";
    writer << "                   " << args[1].get_name() << "[0],\n";
    writer << "                   " << out[0].get_size() << ",\n";
    writer << "                   " << fixed_seed << ");\n";
    writer << "}\n";

    writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Dropout)
{
......
......@@ -150,6 +150,7 @@ namespace ngraph
class Quantize;
class QuantizedConcat;
class Tile;
class RandomUniform;
}
namespace runtime
{
......@@ -443,6 +444,8 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::QuantizedConcat);
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Tile);
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::RandomUniform);
}
}
}
......@@ -80,6 +80,7 @@
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment