Commit 8c38db04 authored by gaurides, committed by Scott Cyphers

Fuse Dropout (#3006)

* Initial implementation

* Added test case

* Bug fix; Dropout with 2 outputs, WIP

* Fixed in unit-test; WIP for model

* Nothing is working

* Revert "Nothing is working"

This reverts commit d3ff09bb7a0d0519ab70ac85f2e7f30721afea96.

* Fixed unit-test; fusion with 2 outputs

* Fix style check, file permissions

* Changed input arg to Node

* Fix order of declaration

* Improved performance

* some cleanup

* Fixed CI error

* Fixed review comments

* Fix CI error

* Remove unused variable

* Fix other CI errors

* Changed type

* Fix style check

* Add codegen code for Dropout

* addressed PR feedback; will add codegen support later

* Cleanup; change variable name

* Support for use_seed

* Add setter for use_seed

* Add setter for use_seed

* Fix CI error

* Make use_seed as arg

* Fix CI error

* Fix CI error
parent b38f8ce0
......@@ -45,6 +45,7 @@ set(SRC
builder/convert_layout.cpp
builder/convolution.cpp
builder/dot.cpp
builder/dropout.cpp
builder/embedding_lookup.cpp
builder/erf.cpp
builder/gather.cpp
......@@ -99,6 +100,7 @@ set(SRC
op/conv_relu.cpp
op/convert_layout.cpp
op/deconv.cpp
op/dropout.cpp
op/group_conv_bias.cpp
op/halide_op.cpp
op/leaky_relu.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/op/dropout.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/dropout.hpp"
#include "ngraph/state/rng_state.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::Dropout)
{
auto& functors = external_function->get_functors();
auto drop = static_cast<const ngraph::op::Dropout*>(node);
CPUKernelFunctor functor;
auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto out1_buffer_index = external_function->get_buffer_index(out[1].get_name());
size_t element_count = out[0].get_size();
bool use_seed = drop->get_use_seed();
double keep_prob = drop->get_keep_prob();
// Note: as a performance optimization, in addition to running the RNG in parallel across
// threads, we create, seed, and advance each minstd_rand engine here in the builder instead
// of in the kernel. Doing this once at build time saved 30% versus doing it in the kernel;
// msr.discard() has the biggest impact on performance. The discard is only needed when
// use_seed == true, so that the same seed reproduces the same mask.
size_t nthr = ngraph::runtime::cpu::executor::GetCPUExecutor().get_num_cores();
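// chunk_size = ceil(element_count / nthr); each thread owns one contiguous chunk of the mask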
size_t chunk_size = (element_count + nthr - 1) / nthr;
std::vector<std::minstd_rand> vmsr(nthr);
if (use_seed)
{
uint32_t seed = drop->get_seed();
for (size_t i = 0; i < nthr; i++)
{
std::minstd_rand msr;
msr.seed(seed);
msr.discard(i * chunk_size);
vmsr[i] = msr;
}
}
if (args[0].get_element_type() == element::f32)
{
functor = [&,
element_count,
arg_buffer_index,
arg1_buffer_index,
out0_buffer_index,
out1_buffer_index,
keep_prob,
vmsr,
use_seed](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
bool training = static_cast<bool>(
static_cast<float*>(ctx->buffer_data[arg1_buffer_index])[0]);
runtime::cpu::kernel::generate_dropout(
static_cast<float*>(ctx->buffer_data[arg_buffer_index]),
static_cast<float*>(ctx->buffer_data[out0_buffer_index]),
static_cast<float*>(ctx->buffer_data[out1_buffer_index]),
element_count,
training,
keep_prob,
vmsr,
use_seed);
};
}
else if (args[0].get_element_type() == element::f64)
{
functor = [&,
element_count,
arg_buffer_index,
arg1_buffer_index,
out0_buffer_index,
out1_buffer_index,
keep_prob,
vmsr,
use_seed](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
bool training = static_cast<bool>(
static_cast<double*>(ctx->buffer_data[arg1_buffer_index])[0]);
runtime::cpu::kernel::generate_dropout(
static_cast<double*>(ctx->buffer_data[arg_buffer_index]),
static_cast<double*>(ctx->buffer_data[out0_buffer_index]),
static_cast<double*>(ctx->buffer_data[out1_buffer_index]),
element_count,
training,
keep_prob,
vmsr,
use_seed);
};
}
else
{
throw ngraph_error(std::string("Unsupported type ") +
args[0].get_element_type().c_type_string() + " for Dropout");
}
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(Dropout);
}
}
}
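For readers unfamiliar with the discard() trick above, here is a minimal standalone sketch (not part of this commit; the function name parallel_uniform and all values are illustrative) of how seeding every std::minstd_rand identically and advancing engine i by i * chunk_size draws gives each thread its own slice of random numbers, deterministically for a fixed seed and thread count:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Each engine starts from the same seed, then skips the draws owned by lower-numbered
// threads, so the loop below is deterministic for a fixed (nelems, nthr, seed).
std::vector<float> parallel_uniform(size_t nelems, size_t nthr, uint32_t seed)
{
    size_t chunk_size = (nelems + nthr - 1) / nthr; // ceil(nelems / nthr)
    std::vector<float> out(nelems);
    for (size_t tid = 0; tid < nthr; tid++) // imagine each iteration running on its own thread
    {
        std::minstd_rand msr;
        msr.seed(seed);
        msr.discard(tid * chunk_size);
        std::uniform_real_distribution<> gen(0, 1);
        size_t begin = tid * chunk_size;
        size_t end = std::min(begin + chunk_size, nelems);
        for (size_t i = begin; i < end; i++)
        {
            out[i] = static_cast<float>(gen(msr));
        }
    }
    return out;
}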
......@@ -122,6 +122,7 @@
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/dropout.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
......@@ -4002,6 +4003,12 @@ namespace ngraph
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Dropout)
{
throw ngraph_error("Not yet implemented");
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Dequantize)
{
......
......@@ -78,6 +78,7 @@ namespace ngraph
CPUExecutor::CPUExecutor(int num_thread_pools)
: m_num_thread_pools(num_thread_pools)
{
m_num_cores = GetNumCores();
for (int i = 0; i < num_thread_pools; i++)
{
int num_threads_per_pool;
......
......@@ -54,11 +54,13 @@ namespace ngraph
CPUExecutionContext* ectx,
bool use_tbb = false);
int get_num_thread_pools() { return m_num_thread_pools; }
int get_num_cores() { return m_num_cores; }
private:
std::vector<std::unique_ptr<Eigen::ThreadPool>> m_thread_pools;
std::vector<std::unique_ptr<Eigen::ThreadPoolDevice>> m_thread_pool_devices;
std::vector<tbb::task_arena> m_tbb_arenas;
int m_num_thread_pools;
int m_num_cores;
};
extern CPUExecutor& GetCPUExecutor();
......
......@@ -168,6 +168,7 @@
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/dropout.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
......@@ -438,6 +439,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::DeconvolutionBias),
&runtime::cpu::CPU_Emitter::emit<ngraph::op::DeconvolutionBias>},
{TI(ngraph::op::QuantizedConcat), &runtime::cpu::CPU_Emitter::emit<op::QuantizedConcat>},
{TI(ngraph::op::Dropout), &runtime::cpu::CPU_Emitter::emit<op::Dropout>},
{TI(ngraph::op::Tile), &runtime::cpu::CPU_Emitter::emit<op::Tile>},
};
......
......@@ -18,6 +18,8 @@
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>
// CBLAS types and wrappers
......@@ -259,6 +261,15 @@ namespace ngraph
const Shape& indices_shape,
const Shape& updates_shape,
int arena);
template <typename T>
void generate_dropout(T* input,
T* out0,
T* out1_mask,
size_t nelems,
bool training,
const double keep_prob,
const std::vector<std::minstd_rand>& vmsr,
const bool use_seed);
}
}
}
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <random>
#include "ngraph/shape.hpp"
#include "ngraph/state/rng_state.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
// Note: this kernel implements inverted dropout (upscale-in-train): kept values are
// scaled by 1 / keep_prob during training so that inference needs no rescaling.
template <typename T>
void generate_dropout(T* input,
T* out0,
T* out1_mask,
const size_t nelems,
const bool training,
const double keep_prob,
const std::vector<std::minstd_rand>& vmsr,
const bool use_seed)
{
if (training)
{
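// when use_seed == false, draw a fresh base seed for this call so successive calls produce different masks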
int32_t rnd_seed = rand();
double dropout_prob = 1 - keep_prob;
#ifdef _OPENMP
size_t nthr =
ngraph::runtime::cpu::executor::GetCPUExecutor().get_num_cores();
size_t chunk_size = (nelems + nthr - 1) / nthr;
#pragma omp parallel num_threads(nthr)
{
size_t tid = omp_get_thread_num();
#else
size_t chunk_size = nelems;
{
size_t tid = 0;
#endif
/* Note:
This dropout implementation is intended to match the PaddlePaddle (PDPD)
native implementation (and other frameworks):
https://github.com/NervanaSystems/ngraph-paddle/blob/14d88829b386c9f7601788c5539c08326dcbe2fe/paddle/fluid/operators/dropout_op.h#L58-L78
so if the framework passes the same seed, we generate the same mask. */
std::minstd_rand msr;
if (use_seed)
{
msr = vmsr[tid];
}
else
{
msr.seed(rnd_seed + tid);
}
std::uniform_real_distribution<> gen(0, 1);
size_t idx_start = tid * chunk_size;
size_t idx_end = std::min(idx_start + chunk_size, nelems);
for (size_t idx = idx_start; idx < idx_end; ++idx)
{
if (static_cast<T>(gen(msr)) < dropout_prob)
{
out1_mask[idx] = 0;
out0[idx] = 0;
}
else
{
out1_mask[idx] = 1;
out0[idx] = input[idx] / static_cast<T>(keep_prob);
}
}
}
}
else
{
// Inference path: ideally this case is optimized away earlier in the graph.
for (size_t i = 0; i < nelems; i++)
{
out1_mask[i] = 1;
out0[i] = static_cast<T>(1);
}
}
}
}
}
}
}
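For reference, a minimal sketch of the upscale-in-train rule the kernel applies per element (not part of this commit; apply_inverted_dropout and the example keep_prob are illustrative): kept values are divided by keep_prob so the expected output equals the input and inference needs no extra scaling.

// Sketch: inverted-dropout rule for one element x, given a uniform draw r in [0, 1).
// With keep_prob = 0.75: if r < 0.25 the element is dropped (out = 0, mask = 0),
// otherwise it is kept and scaled (out = x / 0.75, mask = 1).
// Expected output: 0.75 * (x / 0.75) + 0.25 * 0 = x.
template <typename T>
void apply_inverted_dropout(T x, double r, double keep_prob, T& out, T& mask)
{
    double dropout_prob = 1.0 - keep_prob;
    if (r < dropout_prob)
    {
        mask = 0;
        out = 0;
    }
    else
    {
        mask = 1;
        out = x / static_cast<T>(keep_prob);
    }
}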
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/op/dropout.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::Dropout::Dropout(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& gm_const,
const std::shared_ptr<Node>& use_seed,
const uint32_t seed,
const double keep_prob)
: Op("Dropout", check_single_output_args({input, gm_const, use_seed}))
, m_seed(seed)
, m_keep_prob(keep_prob)
{
constructor_validate_and_infer_types();
set_output_size(2);
set_output_type(0, get_input_element_type(0), input->get_shape());
set_output_type(1, get_input_element_type(0), input->get_shape());
}
shared_ptr<Node> op::Dropout::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Dropout>(
new_args.at(0), new_args.at(1), new_args.at(2), m_seed, m_keep_prob);
}
bool op::Dropout::get_use_seed() const
{
bool use_seed = false;
if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(2)))
{
auto use_seed_ptr = static_cast<const int32_t*>(const_op->get_data_ptr());
use_seed = static_cast<bool>(*use_seed_ptr);
}
return use_seed;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
class Dropout : public Op
{
public:
Dropout(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& gm_const,
const std::shared_ptr<Node>& use_seed,
const uint32_t seed,
const double keep_prob); // keep_prob = 1 - dropout_prob
bool get_use_seed() const;
uint32_t get_seed() const { return m_seed; }
double get_keep_prob() const { return m_keep_prob; }
void set_seed(uint32_t new_seed) { m_seed = new_seed; }
void set_keep_prob(double new_keep_prob) { m_keep_prob = new_keep_prob; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
uint32_t m_seed;
double m_keep_prob;
};
}
}
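A usage sketch of the fused op (illustrative shapes and values, reusing graph-building helpers that appear elsewhere in this commit): the op takes the data, a training flag, and a use_seed constant, and its two outputs are retrieved with GetOutputElement.

auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 256, 256});
auto training = op::Constant::create(element::f32, Shape{}, {1}); // non-zero => training mode
auto use_seed = op::Constant::create(element::i32, Shape{}, {1}); // non-zero => reproducible mask
uint32_t seed = 1234;
double keep_prob = 0.9; // i.e. dropout_prob = 0.1
auto drop = std::make_shared<op::Dropout>(data, training, use_seed, seed, keep_prob);
auto output = std::make_shared<op::GetOutputElement>(drop, 0); // scaled activations
auto mask = std::make_shared<op::GetOutputElement>(drop, 1);   // 0/1 dropout mask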
......@@ -22,8 +22,12 @@
#include "cpu_fusion.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
......@@ -36,6 +40,7 @@
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
......@@ -72,6 +77,7 @@
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/dropout.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
......@@ -911,6 +917,71 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_add()
this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_dropout()
{
Shape shape{1, 1, 2, 2};
auto x = std::make_shared<pattern::op::Label>(element::f32, shape);
auto x_label = std::make_shared<pattern::op::Label>(x, nullptr, NodeVector{x});
uint32_t seed = 1234;
auto seed_label = std::make_shared<pattern::op::Label>(element::u32, Shape{0});
double value = 0.9;
auto value_const = ngraph::op::Constant::create(element::f32, Shape{1, 1, 2, 2}, {value});
auto value_label = std::make_shared<pattern::op::Label>(value_const);
auto const1 = ngraph::op::Constant::create(x->get_element_type(), Shape{}, {1});
auto const1_label = std::make_shared<pattern::op::Label>(const1);
bool use_seed = false;
auto use_seed_const = ngraph::op::Constant::create(element::i32, Shape{}, {use_seed});
auto use_seed_label = std::make_shared<pattern::op::Label>(use_seed_const);
auto genmask = std::make_shared<op::GenerateMask>(
const1_label, x->get_shape(), x->get_element_type(), seed, value, use_seed);
auto genmask_label =
std::make_shared<pattern::op::Label>(genmask, nullptr, NodeVector{genmask});
auto mult = std::make_shared<ngraph::op::Multiply>(genmask_label, x_label);
auto pdivide = std::make_shared<ngraph::op::Divide>(mult, value_label);
auto callback = [x, const1_label, seed_label, value_label, genmask_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_dropout against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto m_div = std::static_pointer_cast<ngraph::op::Divide>(m.get_match_root());
auto gm = std::static_pointer_cast<ngraph::op::GenerateMask>(pattern_map[genmask_label]);
if (!std::dynamic_pointer_cast<ngraph::op::Constant>(gm->get_argument(0)))
{
NGRAPH_DEBUG << "training argument to GenerateMask must be constant";
return false;
}
auto gm_value = gm->get_probability();
auto gm_seed = gm->get_seed();
auto training = gm->get_argument(0); // for training this is always 1
auto use_seed_arg = gm->get_argument(2); // this is the use_seed node
auto dropout_n = std::make_shared<ngraph::op::Dropout>(
pattern_map[x], training, use_seed_arg, gm_seed, gm_value);
auto goe1 = std::make_shared<ngraph::op::GetOutputElement>(dropout_n, 0);
ngraph::replace_node(m.get_match_root(), goe1);
auto goe2 = std::make_shared<ngraph::op::GetOutputElement>(dropout_n, 1);
ngraph::replace_node(pattern_map[genmask_label], goe2);
return true;
};
auto m = std::make_shared<pattern::Matcher>(pdivide, "CPUFusion.Dropout");
this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_add_relu()
{
Shape shape{2, 2, 1, 1};
......
......@@ -77,6 +77,7 @@ public:
construct_deconvolution_affine_folding();
construct_deconvolution_affine_folding_relu();
}
construct_dropout();
}
}
......@@ -105,6 +106,7 @@ private:
void construct_fuse_lstm_recurrent_state();
void construct_deconvolution_affine_folding();
void construct_deconvolution_affine_folding_relu();
void construct_dropout();
};
class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUQuantFusion : public ngraph::pass::GraphRewrite
......
......@@ -30,6 +30,7 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
......@@ -41,6 +42,7 @@
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sum.hpp"
......@@ -64,6 +66,7 @@
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/dropout.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
......@@ -2809,6 +2812,79 @@ TEST(cpu_fusion, fuse_bounded_relu_inter_vs_cpu)
check_bounded_relu(Shape{4, 3, 2}, 2.0f);
}
TEST(cpu_fusion, fuse_dropout)
{
auto make_function = [](Shape input_shape,
const uint32_t seed_val,
double one_minus_prob,
bool fuse,
bool use_seed) {
auto input = std::make_shared<op::Parameter>(element::f32, input_shape);
auto value = op::Constant::create(element::f32, input_shape, {one_minus_prob});
auto const1 = op::Constant::create(input->get_element_type(), Shape{}, {1});
auto gen_mask = std::make_shared<op::GenerateMask>(const1,
input->get_shape(),
input->get_element_type(),
seed_val,
one_minus_prob,
use_seed);
auto mult = std::make_shared<op::Multiply>(gen_mask, input);
auto goe = std::make_shared<op::GetOutputElement>(mult, 0);
auto pdivide = fuse ? std::make_shared<op::Divide>(mult, value)
: std::make_shared<op::Divide>(goe, value);
auto f = make_shared<Function>(NodeVector{pdivide, gen_mask}, ParameterVector{input});
return f;
};
uint32_t seed = rand();
auto fuse_func = make_function(Shape{2, 2, 256, 256}, seed, 0.9, true, true);
auto fuse_func2 = make_function(Shape{2, 2, 256, 256}, seed, 0.9, true, true);
auto nofuse_func = make_function(Shape{2, 2, 256, 256}, 1, 0.9, false, false);
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(fuse_func);
pass_manager.run_passes(nofuse_func);
ASSERT_EQ(count_ops_of_type<op::Dropout>(fuse_func), 1);
ASSERT_EQ(count_ops_of_type<op::GenerateMask>(fuse_func), 0);
ASSERT_EQ(count_ops_of_type<op::Dropout>(nofuse_func), 0);
}
auto fuse_func3 = make_function(Shape{2, 2, 256, 256}, seed, 0.9, true, false);
auto fuse_func4 = make_function(Shape{2, 2, 256, 256}, seed, 0.9, true, false);
{
test::Uniform<float> rng(1.0f, 100.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : fuse_func->get_parameters())
{
auto name = param->get_name();
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto fuse_results = execute(fuse_func, args, "CPU");
auto fuse_results2 = execute(fuse_func2, args, "CPU");
EXPECT_TRUE(test::all_close(fuse_results.at(0), fuse_results2.at(0)));
EXPECT_TRUE(test::all_close(fuse_results.at(1), fuse_results2.at(1)));
auto fuse_results3 = execute(fuse_func3, args, "CPU");
auto fuse_results4 = execute(fuse_func4, args, "CPU");
EXPECT_FALSE(test::all_close(fuse_results3.at(0), fuse_results4.at(0)));
EXPECT_FALSE(test::all_close(fuse_results3.at(1), fuse_results4.at(1)));
// Note: since the RNG used in the Dropout kernel differs from the one used in the
// GenerateMask kernel, we can't compare fuse_results with nofuse_results.
}
}
TEST(cpu_fusion, fuse_leaky_relu)
{
auto make_function = [](Shape input_shape, vector<float> alpha_val) {
......