Commit 5b994011 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Dropout for CPU (#1942)

* generate mask

* add codegen+dex

* states in context

* CPU dropout

* remove dead code

* remove dead code

* remove state.cpp

* change perms, add state.cpp

* address bobs feedback

* restore igpu unit-test manifest after a bad merge

* better error msgs

* throw on GPUs to keep a compiler happy

* address more feedback

* fix tests
parent 6c1ba614
......@@ -61,6 +61,7 @@ set (SRC
op/exp.cpp
op/floor.cpp
op/function_call.cpp
op/experimental/generate_mask.cpp
op/get_output_element.cpp
op/greater.cpp
op/greater_eq.cpp
......@@ -148,6 +149,7 @@ set (SRC
runtime/aligned_buffer.cpp
runtime/backend.cpp
runtime/backend_manager.cpp
state/rng_state.cpp
runtime/host_tensor.cpp
runtime/tensor.cpp
serializer.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/generate_mask.hpp"
using namespace std;
using namespace ngraph;
op::GenerateMask::GenerateMask(const std::shared_ptr<Node>& training,
const Shape& shape,
const element::Type& element_type,
unsigned int seed,
double prob)
: Op("GenerateMask", check_single_output_args({training}))
, m_shape(shape)
, m_element_type(element_type)
, m_seed(seed)
, m_probability(prob)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::GenerateMask::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<GenerateMask>(
new_args.at(0), m_shape, m_element_type, m_seed, m_probability);
}
void ngraph::op::GenerateMask::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this, get_input_partial_shape(0).compatible(PartialShape{}))
<< "Training node should be a scalar flag indicating a mode";
NODE_VALIDATION_ASSERT(this, m_element_type.is_static())
<< "Output element type must not be dynamic.";
set_output_type(0, m_element_type, m_shape);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/state/rng_state.hpp"
namespace ngraph
{
namespace op
{
/// \brief GenerateMask
///
class GenerateMask : public op::Op
{
public:
/// \brief Constructs a GenerateMask node with a given shape, sed,
/// probability and training/inference mode
GenerateMask(const std::shared_ptr<Node>& training,
const Shape& shape,
const element::Type& element_type,
unsigned int seed,
double prob);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \brief Returns the probability of a trial generating 1 (i.e. an element being kept)
double get_probability() const { return m_probability; }
/// \brief Returns the seed value supplied to a random generator
unsigned int get_seed() const { return m_seed; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override
{
}
void validate_and_infer_types() override;
Shape m_shape;
element::Type m_element_type;
unsigned int m_seed;
double m_probability;
};
}
}
......@@ -76,6 +76,7 @@ NGRAPH_OP(Equal, ngraph::op)
NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(Floor, ngraph::op)
NGRAPH_OP(FunctionCall, ngraph::op)
NGRAPH_OP(GenerateMask, ngraph::op)
NGRAPH_OP(GetOutputElement, ngraph::op)
NGRAPH_OP(Greater, ngraph::op)
NGRAPH_OP(GreaterEq, ngraph::op)
......
......@@ -69,6 +69,7 @@ set(SRC
builder/softmax.cpp
builder/sum.cpp
builder/topk.cpp
builder/state.cpp
builder/quantization.cpp
kernel/eigen_thread_pool.cpp
kernel/pad.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/generate_mask.hpp"
#include "ngraph/state/rng_state.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::GenerateMask)
{
auto& functors = external_function->get_functors();
auto gm = static_cast<const ngraph::op::GenerateMask*>(node);
function<void(CPURuntimeContext*)> functor;
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
size_t element_count = out[0].get_size();
auto index = external_function->add_state(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
if (args[0].get_element_type() == element::f32)
{
functor = [&, index, element_count](CPURuntimeContext* ctx) {
bool training = static_cast<bool>(static_cast<float*>(arg_tensor)[0]);
reference::generate_mask(static_cast<float*>(out_tensor),
element_count,
static_cast<RNGState*>(ctx->states[index]),
training);
};
}
else if (args[0].get_element_type() == element::f64)
{
functor = [&, index, element_count](CPURuntimeContext* ctx) {
bool training = static_cast<bool>(static_cast<double*>(arg_tensor)[0]);
reference::generate_mask(static_cast<double*>(out_tensor),
element_count,
static_cast<RNGState*>(ctx->states[index]),
training);
};
}
else
{
throw ngraph_error(std::string("Unsupported type") +
args[0].get_element_type().c_type_string() +
"for GenerateMask");
}
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(GenerateMask);
}
}
}
......@@ -130,6 +130,7 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
const auto& mkldnn_emitter = m_external_function->get_mkldnn_emitter();
ctx->mkldnn_primitives = mkldnn_emitter->get_mkldnn_primitives().data();
ctx->mkldnn_workspaces = mkldnn_emitter->get_mkldnn_workspaces().data();
ctx->states = m_external_function->m_states.data();
if (std::getenv("NGRAPH_CPU_USE_TBB") != nullptr)
{
......
......@@ -47,6 +47,7 @@
#include "ngraph/op/dot.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
......@@ -4749,6 +4750,23 @@ namespace ngraph
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::GenerateMask)
{
auto gm = static_cast<const ngraph::op::GenerateMask*>(node);
writer.block_begin();
auto index = external_function->add_state(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
writer << "auto state = static_cast<ngraph::RNGState*>(ctx->states[" << index
<< "]);\n";
writer << "bool training = static_cast<bool>(" << args[0].get_name() << "[0]);\n";
writer << "reference::generate_mask(";
writer << " " << out[0].get_name() << ",\n";
writer << " " << out[0].get_size() << ",\n";
writer << " state, training);\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Dequantize)
{
......
......@@ -63,6 +63,7 @@
#include "ngraph/op/dot.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
......@@ -193,6 +194,10 @@ runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction(
runtime::cpu::CPU_ExternalFunction::~CPU_ExternalFunction()
{
for (auto state : m_states)
{
delete state;
}
}
class StaticInitializers
......@@ -358,12 +363,12 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::runtime::cpu::op::LoopKernel),
&runtime::cpu::CPU_Emitter::emit<runtime::cpu::op::LoopKernel>},
{TI(ngraph::op::LRN), &runtime::cpu::CPU_Emitter::emit<ngraph::op::LRN>},
{TI(ngraph::op::GenerateMask), &runtime::cpu::CPU_Emitter::emit<ngraph::op::GenerateMask>},
{TI(ngraph::op::ConvolutionAdd), &runtime::cpu::CPU_Emitter::emit<op::ConvolutionAdd>},
{TI(ngraph::op::Quantize), &runtime::cpu::CPU_Emitter::emit<op::Quantize>},
{TI(ngraph::op::Dequantize), &runtime::cpu::CPU_Emitter::emit<op::Dequantize>},
{TI(ngraph::op::Quantize), &runtime::cpu::CPU_Emitter::emit<ngraph::op::Quantize>},
{TI(ngraph::op::Dequantize), &runtime::cpu::CPU_Emitter::emit<ngraph::op::Dequantize>},
{TI(ngraph::op::GroupConvolutionBias),
&runtime::cpu::CPU_Emitter::emit<op::GroupConvolutionBias>},
};
static void
......@@ -443,6 +448,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/generate_mask.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/max_pool.hpp"
......@@ -466,6 +472,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/state/rng_state.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/util.hpp"
......
......@@ -47,6 +47,7 @@
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_emitter.hpp"
#include "ngraph/runtime/performance_counter.hpp"
#include "ngraph/state/state.hpp"
namespace ngraph
{
......@@ -117,6 +118,12 @@ namespace ngraph
return m_mkldnn_emitter;
}
size_t add_state(ngraph::State* state)
{
m_states.push_back(state);
return m_states.size() - 1;
}
const std::string& get_function_name() const { return m_function_name; }
const std::shared_ptr<ngraph::Function> get_function() { return m_function; }
// Temporary Memory Pool alignment
......@@ -174,6 +181,8 @@ namespace ngraph
#endif
std::vector<ngraph::State*> m_states;
private:
// Register passes that are common to codegen and DEX
void register_common_passes(ngraph::pass::Manager& pass_manager);
......
......@@ -37,6 +37,8 @@ namespace ngraph
{
class AlignedBuffer;
}
class State;
}
namespace ngraph
......@@ -61,6 +63,7 @@ namespace ngraph
tbb::flow::graph* G;
tbb::global_control* c;
tbb::task_scheduler_init* init;
State* const* states;
std::set<size_t> breakpoints;
size_t pc;
};
......
......@@ -54,6 +54,7 @@
#include "ngraph/op/dot.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/function_call.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -657,6 +658,11 @@ void runtime::gpu::GPU_Emitter::emit_FunctionCall(EMIT_ARGS)
writer.block_end();
}
void runtime::gpu::GPU_Emitter::emit_GenerateMask(EMIT_ARGS)
{
throw ngraph_error("GenerateMask is not supported yet on NVIDIA GPU");
}
void runtime::gpu::GPU_Emitter::emit_GetOutputElement(EMIT_ARGS)
{
auto get_tuple_element = static_cast<const ngraph::op::GetOutputElement*>(node);
......
......@@ -10,6 +10,7 @@ concat_matrix_int64
divide_by_zero_int32
#int64 is not supprted by cuDNN
dot_matrix_vector_int64
generate_mask
#no mkldnn on GPU
#error throw is not the same on GPU, not supported yet
one_hot_scalar_fp_nonint_in_3
......
......@@ -228,6 +228,7 @@ function_name
fuse_max_with_constant_zero_input_as_relu
greater
greatereq
generate_mask
kahan_sum_3d_to_vector
kahan_sum_to_scalar
less
......
......@@ -432,6 +432,10 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
function_output_names.push_back(get_input_name(op));
break;
}
case OP_TYPEID::GenerateMask:
{
throw ngraph_error("GenerateMask isn't yet supported on integrated GPU");
}
case OP_TYPEID::GetOutputElement:
{
if (op->get_inputs().empty() || op->get_outputs().size() != 1)
......
......@@ -25,6 +25,7 @@ divide_by_zero_int32
dot_3d_multi_axis
dot_4d_5d_multi_axis
dot_4d_5d_multi_axis_more
generate_mask
function_call
max_pool_3d
numeric_double_inf
......
......@@ -31,6 +31,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
......@@ -81,6 +82,7 @@
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/generate_mask.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
......@@ -314,6 +316,11 @@ private:
avg_pool->get_include_padding_in_avg_computation());
break;
}
case OP_TYPEID::GenerateMask:
{
throw ngraph_error(
"GenerateMask is an experimental op that's only supported on CPU backend");
}
case OP_TYPEID::GetOutputElement:
{
const op::GetOutputElement* get_output_element =
......
......@@ -7,4 +7,5 @@ batchnorm_fprop_inference_b2c2h2w1
batchnorm_fprop_bprop
batchnorm_fprop_bprop_2step
computation_reuse
generate_mask
topk_int64
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <random>
#include "ngraph/state/rng_state.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void generate_mask(T* out, size_t count, ngraph::RNGState* rng_state, bool training)
{
auto& gen = rng_state->get_generator();
auto& bd = rng_state->get_distribution();
for (size_t i = 0; i < count; i++)
{
out[i] = training ? static_cast<T>(bd(gen)) : static_cast<T>(1);
}
}
}
}
}
......@@ -44,6 +44,7 @@
#include "ngraph/op/dot.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/function_call.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -731,6 +732,17 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::FunctionCall>(f_ptr, args);
break;
}
case OP_TYPEID::GenerateMask:
{
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
auto type = read_element_type(node_js.at("type"));
auto seed = node_js.at("seed").get<unsigned int>();
auto probability = node_js.at("probability").get<double>();
node =
make_shared<op::GenerateMask>(args[0], output_shape, type, seed, probability);
break;
}
case OP_TYPEID::GetOutputElement:
{
node = make_shared<op::GetOutputElement>(args[0], node_js.at("n").get<size_t>());
......@@ -1365,6 +1377,15 @@ static json write(const Node& n, bool binary_constant_data)
node["n"] = tmp->get_n();
break;
}
case OP_TYPEID::GenerateMask:
{
auto tmp = dynamic_cast<const op::GenerateMask*>(&n);
node["output_shape"] = tmp->get_shape();
node["type"] = write_element_type(tmp->get_element_type());
node["seed"] = tmp->get_seed();
node["probability"] = tmp->get_probability();
break;
}
case OP_TYPEID::Greater: { break;
}
case OP_TYPEID::GreaterEq: { break;
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <random>
#include "except.hpp"
#include "rng_state.hpp"
using namespace std;
using namespace ngraph;
void ngraph::RNGState::activate()
{
}
void ngraph::RNGState::deactivate()
{
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <random>
#include "state.hpp"
namespace ngraph
{
//can be based on TensorSate to cache values instead of just caching seed
class RNGState : public State
{
public:
static RNGState* create_rng_state(unsigned int seed, double probability)
{
auto rng = new RNGState(seed, probability);
return rng;
}
RNGState(unsigned int seed, double probability)
: State()
, m_generator(seed)
, m_distribution(probability)
{
}
virtual void activate() override;
virtual void deactivate() override;
virtual ~RNGState() {}
std::mt19937& get_generator() { return m_generator; }
std::bernoulli_distribution& get_distribution() { return m_distribution; }
protected:
std::mt19937 m_generator;
std::bernoulli_distribution m_distribution;
};
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
namespace ngraph
{
class State
{
public:
//TODO: add name and id
State() {}
virtual void activate() = 0;
virtual void deactivate() = 0;
bool is_active() const { return m_is_active; }
void set_active(bool flag) { m_is_active = flag; }
virtual ~State() {}
protected:
bool m_is_active = false;
};
}
......@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <algorithm>
#include <cinttypes>
#include <cmath>
......@@ -26,7 +27,9 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/state/rng_state.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
......@@ -4764,6 +4767,36 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_sequence_n4d2c3h2w2)
EXPECT_EQ(read_vector<int>(result), expected);
}
NGRAPH_TEST(${BACKEND_NAME}, generate_mask)
{
Shape scalar{};
Shape result_shape{1, 128};
const unsigned int seed = 777;
auto training = op::Constant::create(element::f32, Shape{}, {1});
auto gen_mask = make_shared<op::GenerateMask>(training, result_shape, element::f32, seed, 0.5);
auto gen_mask2 = make_shared<op::GenerateMask>(training, result_shape, element::f32, seed, 0.5);
auto f = make_shared<Function>(NodeVector{gen_mask, gen_mask2}, op::ParameterVector{});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto is_not_zero_or_one = [](float num) { return num != 0.f && num != 1.f; };
auto result_tv1 = backend->create_tensor<float>(result_shape);
auto result_tv2 = backend->create_tensor<float>(result_shape);
backend->call_with_validate(f, {result_tv1, result_tv2}, {});
auto result1 = read_vector<float>(result_tv1);
auto result2 = read_vector<float>(result_tv2);
ASSERT_EQ(result1, result2);
ASSERT_FALSE(std::any_of(result1.begin(), result1.end(), is_not_zero_or_one));
backend->call_with_validate(f, {result_tv1, result_tv2}, {});
auto result1_2 = read_vector<float>(result_tv1);
auto result2_2 = read_vector<float>(result_tv2);
ASSERT_NE(result1, result1_2);
ASSERT_FALSE(std::any_of(result1_2.begin(), result1_2.end(), is_not_zero_or_one));
ASSERT_NE(result2, result2_2);
ASSERT_FALSE(std::any_of(result2_2.begin(), result2_2.end(), is_not_zero_or_one));
}
NGRAPH_TEST(${BACKEND_NAME}, quantize)
{
Shape input_shape{4, 3};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment