Commit 3f017a1e authored by Amy Zhuang, committed by Scott Cyphers

Use cpu kernel for constant folding. (#2538)

* Use cpu kernel for constant folding.

* Add default empty map.

* Fix a bug.

* Add new files.

* Address PR feedback.

* Check constant folding map before checking type for unary and binary ops.

* Address PR feedback.

* Address PR feedback.

* Use all_close_f.

Add relu unit test.

Make changes for sqrt and pad.

* Fix a bug.
parent 9fea22b2
@@ -17,6 +17,7 @@
 #pragma once
 #include "ngraph/pass/graph_rewrite.hpp"
+#include "ngraph/util.hpp"
 namespace ngraph
 {
@@ -40,9 +41,10 @@ public:
         QUANTIZE
     };

-    ConstantFolding()
+    ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
         : GraphRewrite()
     {
+        m_cfmap = cfmap;
         construct_constant_reshape();
         construct_constant_broadcast();
         construct_constant_pad();
@@ -54,9 +56,11 @@ public:
     //this allows to specify the order in which matchers will be run
     //and also allows to register the same matcher more than once
-    ConstantFolding(const std::vector<CFTransformations>& transformations)
+    ConstantFolding(const std::vector<CFTransformations>& transformations,
+                    const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
         : GraphRewrite()
     {
+        m_cfmap = cfmap;
         for (auto cft : transformations)
         {
             switch (cft)
@@ -80,4 +84,6 @@ private:
     void construct_constant_binary();
     void construct_constant_quantize();
     void construct_constant_dequantize();
+
+    ngraph::BuildNodeExecutorMap m_cfmap;
 };
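With this change, a caller can hand ConstantFolding a map of backend kernel builders; with the default empty map, the pass falls back to its generic folding paths. A minimal usage sketch, mirroring the unit tests added later in this diff (`f` is assumed to be a `std::shared_ptr<ngraph::Function>`):

    #include "ngraph/pass/constant_folding.hpp"
    #include "ngraph/pass/manager.hpp"
    #include "ngraph/runtime/cpu/cpu_builder.hpp"

    // Fold constants using the CPU backend's kernels where available.
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::ConstantFolding>(
        ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
    pass_manager.run_passes(f);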
@@ -29,22 +29,18 @@ namespace ngraph
 {
     namespace cpu
     {
-        template <>
-        void Builder::BUILDER_DECL(ngraph::op::Broadcast)
+        static void get_broadcast_kernel(
+            const ngraph::Node* node,
+            std::function<decltype(runtime::cpu::kernel::broadcast<float, 2>)>& kernel,
+            Shape& expanded_input_shape,
+            Shape& out_shape,
+            size_t& size)
         {
-            auto& functors = external_function->get_functors();
             auto broadcast = static_cast<const ngraph::op::Broadcast*>(node);
             auto broadcast_axes = broadcast->get_broadcast_axes();
-            auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
-            auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
-            auto arg_shape = args[0].get_shape();
-            auto out_shape = out[0].get_shape();
-            // TODO(jmenon): Shape transformations, rank reduction etc. needs to be general
-            // and not in any one builder. Move this to the Halide analysis phase.
+            auto arg_shape = broadcast->get_argument(0)->get_shape();
+            out_shape = broadcast->get_shape();

             // Transform output shape - ex. [4, 1, 2, 2] -> [4, 1, 4]
             // if we're not broadcasting along axes 2 and 3
@@ -96,9 +92,7 @@ namespace ngraph
                 else
                 {
                     broadcast_axes.erase(i);
-                    // TODO(jmenon): This needs to be rewritten
-                    // when it gets moved to the analysis pass
-                    // that doesn't use AxisSet
                     auto new_bcast_axes = AxisSet{};
                     for (auto axis : broadcast_axes)
                     {
@@ -128,11 +122,7 @@ namespace ngraph
             if (broadcast_axes.empty())
             {
-                size_t size = out[0].get_size() * out[0].get_element_type().size();
-                auto functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
-                    memcpy(out_tensor, arg_tensor, size);
-                };
-                functors.emplace_back(functor);
+                size = shape_size(out_shape) * broadcast->get_element_type().size();
                 return;
             }
@@ -146,7 +136,7 @@ namespace ngraph
             // so expand as needed
             // Ex. [2] -> [2, 1] for output shape [2, 4]
-            auto expanded_input_shape = Shape(out_rank, 1);
+            expanded_input_shape = Shape(out_rank, 1);
             size_t i = 0;
             for (size_t j = 0; j < out_rank; j++)
             {
@@ -160,17 +150,70 @@ namespace ngraph
                 }
             }

+            SELECT_KERNEL_BY_RANK(kernel,
+                                  broadcast->get_input_element_type(0),
+                                  out_rank,
+                                  runtime::cpu::kernel::broadcast);
+        }
+
+        template <>
+        NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Broadcast)
+        {
+            std::function<decltype(runtime::cpu::kernel::broadcast<float, 2>)> kernel;
+            Shape expanded_input_shape, out_shape;
+            size_t size;
+            get_broadcast_kernel(node, kernel, expanded_input_shape, out_shape, size);
+            NodeExecutorTy functor;
+            if (kernel)
+            {
+                functor = [kernel, expanded_input_shape, out_shape](
+                    const std::vector<void*> inputs, std::vector<void*> outputs) {
+                    kernel(inputs[0], outputs[0], expanded_input_shape, out_shape, 0);
+                };
+            }
+            else
+            {
+                functor = [size](const std::vector<void*>& inputs,
+                                 std::vector<void*>& outputs) {
+                    memcpy(outputs[0], inputs[0], size);
+                };
+            }
+            return functor;
+        }
+        REGISTER_CF_BUILDER(Broadcast);

-            SELECT_KERNEL_BY_RANK(
-                kernel, args[0].get_element_type(), out_rank, runtime::cpu::kernel::broadcast);

+        template <>
+        void Builder::BUILDER_DECL(ngraph::op::Broadcast)
+        {
+            auto& functors = external_function->get_functors();
+            auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
+            auto& out_tensor = external_function->get_tensor_data(out[0].get_name());

-            auto functor = [&, kernel, expanded_input_shape, out_shape](
+            std::function<decltype(runtime::cpu::kernel::broadcast<float, 2>)> kernel;
+            Shape expanded_input_shape, out_shape;
+            size_t size;
+            get_broadcast_kernel(node, kernel, expanded_input_shape, out_shape, size);
+            CPUKernelFunctor functor;
+            if (kernel)
+            {
+                functor = [&, kernel, expanded_input_shape, out_shape](
                     CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
-                kernel(arg_tensor, out_tensor, expanded_input_shape, out_shape, ectx->arena);
+                    kernel(
+                        arg_tensor, out_tensor, expanded_input_shape, out_shape, ectx->arena);
                 };
                 functors.emplace_back(functor);
+            }
+            else
+            {
+                functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
+                    memcpy(out_tensor, arg_tensor, size);
+                };
+                functors.emplace_back(functor);
+            }
         }
         REGISTER_OP_BUILDER(Broadcast);
     }
......
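For orientation, a hedged sketch of how a registered CF builder gets used (not code from this commit): the folding pass looks the node's concrete type up in the BuildNodeExecutorMap, invokes the registered builder to get a NodeExecutorTy, and runs it over the constant buffers. `node` is assumed to be a `const ngraph::Node*` fed by constants; the buffer pointers are illustrative only.

    // cfmap is the map returned by GetGlobalCFDispatcherCPU().
    auto it = cfmap.find(std::type_index(typeid(*node)));
    if (it != cfmap.end())
    {
        // Invokes e.g. Builder::CFbuild<ngraph::op::Broadcast>.
        NodeExecutorTy executor = it->second(node);
        std::vector<void*> inputs{const_input_buffer}; // constant data in (assumed)
        std::vector<void*> outputs{folded_buffer};     // folded constant out (assumed)
        executor(inputs, outputs);
    }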
@@ -50,7 +50,7 @@ namespace ngraph
             auto padding_above = pad->get_padding_above();
             auto pad_mode = pad->get_pad_mode();

-            if (pad->get_pad_mode() == ngraph::op::PadMode::CONSTANT)
+            if (pad_mode == ngraph::op::PadMode::CONSTANT)
             {
                 std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
@@ -97,6 +97,64 @@ namespace ngraph
         }
         REGISTER_OP_BUILDER(Pad);

+        template <>
+        NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Pad)
+        {
+            auto pad = static_cast<const ngraph::op::Pad*>(node);
+            auto arg_shape = pad->get_argument(0)->get_shape();
+            auto out_shape = pad->get_shape();
+            auto padding_below = pad->get_padding_below();
+            auto padding_above = pad->get_padding_above();
+            auto pad_mode = pad->get_pad_mode();
+
+            if (pad_mode == ngraph::op::PadMode::CONSTANT)
+            {
+                std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
+                SELECT_KERNEL_BY_RANK(kernel,
+                                      pad->get_input_element_type(0),
+                                      arg_shape.size(),
+                                      runtime::cpu::kernel::pad_and_slice);
+
+                auto functor = [kernel, arg_shape, out_shape, padding_below, padding_above](
+                    const std::vector<void*>& inputs, std::vector<void*>& outputs) {
+                    kernel(inputs[0],
+                           outputs[0],
+                           inputs[1],
+                           arg_shape,
+                           out_shape,
+                           CoordinateDiff(padding_below.begin(), padding_below.end()),
+                           CoordinateDiff(padding_above.begin(), padding_above.end()),
+                           0);
+                };
+                return functor;
+            }
+            else
+            {
+                std::function<decltype(runtime::cpu::kernel::pad_ref<float>)> kernel;
+                SELECT_KERNEL(
+                    kernel, pad->get_input_element_type(0), runtime::cpu::kernel::pad_ref);
+
+                auto functor =
+                    [kernel, arg_shape, out_shape, padding_below, padding_above, pad_mode](
+                        const std::vector<void*>& inputs, std::vector<void*>& outputs) {
+                        kernel(inputs[0],
+                               inputs[1],
+                               outputs[0],
+                               arg_shape,
+                               out_shape,
+                               padding_below,
+                               padding_above,
+                               pad_mode,
+                               0);
+                    };
+                return functor;
+            }
+        }
+        REGISTER_CF_BUILDER(Pad);
         }
     }
 }
@@ -31,15 +31,31 @@ namespace ngraph
 {
     namespace cpu
     {
-        template <>
-        void Builder::BUILDER_DECL(ngraph::op::Reshape)
+        static void get_reshape_kernel(
+            const ngraph::Node* node,
+            std::function<decltype(runtime::cpu::kernel::reshape_1d<float, 2>)>& kernel,
+            std::function<decltype(runtime::cpu::kernel::reshape_ref<float>)>& ref_kernel,
+            Shape& arg_shape,
+            Shape& result_shape,
+            AxisVector& input_order,
+            size_t& size,
+            bool& skip_reshape)
         {
-            auto& functors = external_function->get_functors();
             auto reshape = static_cast<const ngraph::op::Reshape*>(node);
-            auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
-            auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
+            arg_shape = reshape->get_argument(0)->get_shape();
+            auto arg_rank = arg_shape.size();
+            result_shape = reshape->get_output_shape();
+            auto result_rank = result_shape.size();
+            auto& result_element_type = reshape->get_element_type();
+            input_order = reshape->get_input_order();
+            bool same_layout = is_sorted(input_order.begin(), input_order.end());
+            auto result_size = shape_size(result_shape);
+            size = result_size * result_element_type.size();

             auto can_skip_reshape = [&]() {
                 if (!reshape->get_is_transpose())
@@ -56,41 +72,15 @@ namespace ngraph
             if (can_skip_reshape())
             {
-                size_t size = out[0].get_size() * out[0].get_element_type().size();
-                auto functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
-                    if (out_tensor != arg_tensor)
-                    {
-                        memcpy(out_tensor, arg_tensor, size);
-                    }
-                };
-                functors.emplace_back(functor);
+                skip_reshape = true;
                 return;
             }

-            auto arg_shape = args[0].get_shape();
-            auto arg_rank = arg_shape.size();
-            auto result_shape = out[0].get_shape();
-            auto result_rank = result_shape.size();
-            auto& result_element_type = out[0].get_element_type();
-            auto input_order = reshape->get_input_order();
-            bool same_layout = is_sorted(input_order.begin(), input_order.end());
-            auto result_size = shape_size(result_shape);

             if (same_layout || result_size < 2)
             {
-                size_t size = out[0].get_size() * out[0].get_element_type().size();
-                auto functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
-                    memcpy(out_tensor, arg_tensor, size);
-                };
-                functors.emplace_back(functor);
                 return;
             }

-            std::function<decltype(runtime::cpu::kernel::reshape_1d<float, 2>)> kernel;
             if (arg_rank == 1)
             {
                 SELECT_KERNEL_BY_RANK(
@@ -113,29 +103,128 @@ namespace ngraph
             }
             else
             {
-                std::function<decltype(runtime::cpu::kernel::reshape_ref<float>)> ref_kernel;
                 SELECT_KERNEL(
                     ref_kernel, result_element_type, runtime::cpu::kernel::reshape_ref);
             }
+        }

-                auto functor = [&, ref_kernel, arg_shape, input_order, result_shape](
+        template <>
+        NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Reshape)
+        {
+            std::function<decltype(runtime::cpu::kernel::reshape_1d<float, 2>)> kernel;
+            std::function<decltype(runtime::cpu::kernel::reshape_ref<float>)> ref_kernel;
+            Shape arg_shape, result_shape;
+            AxisVector input_order;
+            size_t size;
+            bool skip_reshape = false;
+            get_reshape_kernel(node,
+                               kernel,
+                               ref_kernel,
+                               arg_shape,
+                               result_shape,
+                               input_order,
+                               size,
+                               skip_reshape);
+            NodeExecutorTy functor;
+            if (kernel)
+            {
+                functor = [kernel, arg_shape, input_order, result_shape](
+                    const std::vector<void*>& inputs, std::vector<void*>& outputs) {
+                    kernel(inputs[0], outputs[0], arg_shape, input_order, result_shape, 0);
+                };
+            }
+            else if (ref_kernel)
+            {
+                functor = [ref_kernel, arg_shape, input_order, result_shape](
+                    std::vector<void*> inputs, std::vector<void*> outputs) {
+                    ref_kernel(inputs[0], outputs[0], arg_shape, input_order, result_shape, 0);
+                };
+            }
+            else if (skip_reshape)
+            {
+                functor = [size](const std::vector<void*>& inputs,
+                                 std::vector<void*>& outputs) {
+                    if (inputs[0] != outputs[0])
+                    {
+                        memcpy(outputs[0], inputs[0], size);
+                    }
+                };
+            }
+            else
+            {
+                functor = [size](const std::vector<void*>& inputs,
+                                 std::vector<void*>& outputs) {
+                    memcpy(outputs[0], inputs[0], size);
+                };
+            }
+            return functor;
+        }
+        REGISTER_CF_BUILDER(Reshape);

+        template <>
+        void Builder::BUILDER_DECL(ngraph::op::Reshape)
+        {
+            auto& functors = external_function->get_functors();
+            auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
+            auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
+
+            std::function<decltype(runtime::cpu::kernel::reshape_1d<float, 2>)> kernel;
+            std::function<decltype(runtime::cpu::kernel::reshape_ref<float>)> ref_kernel;
+            Shape arg_shape, result_shape;
+            AxisVector input_order;
+            size_t size;
+            bool skip_reshape = false;
+            get_reshape_kernel(node,
+                               kernel,
+                               ref_kernel,
+                               arg_shape,
+                               result_shape,
+                               input_order,
+                               size,
+                               skip_reshape);
+            CPUKernelFunctor functor;
+            if (kernel)
+            {
+                functor = [&, kernel, arg_shape, input_order, result_shape](
                     CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
-                    ref_kernel(arg_tensor,
+                    kernel(arg_tensor,
                            out_tensor,
                            arg_shape,
                            input_order,
                            result_shape,
                            ectx->arena);
                 };
-                functors.emplace_back(functor);
-                return;
             }

-            auto functor = [&, kernel, arg_shape, input_order, result_shape](
+            else if (ref_kernel)
+            {
+                functor = [&, ref_kernel, arg_shape, input_order, result_shape](
                     CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
-                kernel(
-                    arg_tensor, out_tensor, arg_shape, input_order, result_shape, ectx->arena);
+                    ref_kernel(arg_tensor,
+                               out_tensor,
+                               arg_shape,
+                               input_order,
+                               result_shape,
+                               ectx->arena);
                 };
+            }
+            else if (skip_reshape)
+            {
+                functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
+                    if (out_tensor != arg_tensor)
+                    {
+                        memcpy(out_tensor, arg_tensor, size);
+                    }
+                };
+            }
+            else
+            {
+                functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
+                    memcpy(out_tensor, arg_tensor, size);
+                };
+            }
             functors.emplace_back(functor);
         }
......
@@ -26,6 +26,7 @@
 #include "ngraph/node.hpp"
 #include "ngraph/op/abs.hpp"
 #include "ngraph/op/acos.hpp"
+#include "ngraph/op/add.hpp"
 #include "ngraph/op/and.hpp"
 #include "ngraph/op/asin.hpp"
 #include "ngraph/op/atan.hpp"
@@ -53,6 +54,7 @@
 #include "ngraph/op/or.hpp"
 #include "ngraph/op/parameter.hpp"
 #include "ngraph/op/power.hpp"
+#include "ngraph/op/relu.hpp"
 #include "ngraph/op/result.hpp"
 #include "ngraph/op/sign.hpp"
 #include "ngraph/op/sin.hpp"
@@ -65,6 +67,7 @@
 #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
 #include "ngraph/runtime/cpu/kernel/abs.hpp"
 #include "ngraph/runtime/cpu/kernel/acos.hpp"
+#include "ngraph/runtime/cpu/kernel/add.hpp"
 #include "ngraph/runtime/cpu/kernel/and.hpp"
 #include "ngraph/runtime/cpu/kernel/asin.hpp"
 #include "ngraph/runtime/cpu/kernel/atan.hpp"
@@ -89,6 +92,7 @@
 #include "ngraph/runtime/cpu/kernel/not.hpp"
 #include "ngraph/runtime/cpu/kernel/not_equal.hpp"
 #include "ngraph/runtime/cpu/kernel/or.hpp"
+#include "ngraph/runtime/cpu/kernel/relu.hpp"
 #include "ngraph/runtime/cpu/kernel/result.hpp"
 #include "ngraph/runtime/cpu/kernel/sign.hpp"
 #include "ngraph/runtime/cpu/kernel/sin.hpp"
@@ -365,6 +369,66 @@ namespace ngraph
                 functors.emplace_back(functor);
             }

+            template <>
+            NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Add)
+            {
+                BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::add);
+            }
+
+            template <>
+            NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Subtract)
+            {
+                BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::subtract);
+            }
+
+            template <>
+            NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Multiply)
+            {
+                BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::multiply);
+            }
+
+            template <>
+            NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Divide)
+            {
+                BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::divide);
+            }
+
+            template <>
+            NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Minimum)
+            {
+                BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::minimum);
+            }
+
+            template <>
+            NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Maximum)
+            {
+                BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::maximum);
+            }
+
+            template <>
+            NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Abs)
+            {
+                BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::abs);
+            }
+
+            template <>
+            NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Negative)
+            {
+                BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::negative);
+            }
+
+            template <>
+            NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Relu)
+            {
+                BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::relu);
+            }
+
+            template <>
+            NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Sqrt)
+            {
+                BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sqrt);
+            }
+
 #define TI(x) type_index(typeid(x))

             BuildOpMap& GetGlobalBuildDispatcher()
@@ -379,6 +443,12 @@ namespace ngraph
                 return build_dispatcher;
             }

+            BuildNodeExecutorMap& GetGlobalCFDispatcherCPU()
+            {
+                static BuildNodeExecutorMap build_cf_dispatcher_cpu{};
+                return build_cf_dispatcher_cpu;
+            }
+
             REGISTER_OP_BUILDER(Constant);
             REGISTER_OP_BUILDER(Result);
             REGISTER_OP_BUILDER(Subtract);
@@ -414,6 +484,17 @@ namespace ngraph
             REGISTER_OP_BUILDER(Minimum);
             REGISTER_OP_BUILDER(And);
             REGISTER_OP_BUILDER(Or);
+
+            REGISTER_CF_BUILDER(Add);
+            REGISTER_CF_BUILDER(Subtract);
+            REGISTER_CF_BUILDER(Multiply);
+            REGISTER_CF_BUILDER(Divide);
+            REGISTER_CF_BUILDER(Minimum);
+            REGISTER_CF_BUILDER(Maximum);
+            REGISTER_CF_BUILDER(Abs);
+            REGISTER_CF_BUILDER(Negative);
+            REGISTER_CF_BUILDER(Relu);
+            REGISTER_CF_BUILDER(Sqrt);
         }
     }
 }
@@ -232,6 +232,32 @@
     };                                                                             \
     functors.emplace_back(functor);

+#define BUILD_UNARY_ELEMWISE_CF_FUNCTOR(OP)                                        \
+    std::function<void(void*, void*, size_t, int)> kernel;                         \
+                                                                                   \
+    SELECT_KERNEL(kernel, node->get_input_element_type(0), OP);                    \
+                                                                                   \
+    auto element_count = shape_size(node->get_shape());                            \
+                                                                                   \
+    auto functor = [&, kernel, element_count](const std::vector<void*>& inputs,    \
+                                              std::vector<void*>& outputs) {       \
+        kernel(inputs[0], outputs[0], element_count, 0);                           \
+    };                                                                             \
+    return functor;
+
+#define BUILD_BINARY_ELEMWISE_CF_FUNCTOR(OP)                                       \
+    std::function<void(void*, void*, void*, size_t, int)> kernel;                  \
+                                                                                   \
+    SELECT_KERNEL(kernel, node->get_input_element_type(0), OP);                    \
+                                                                                   \
+    auto element_count = shape_size(node->get_shape());                            \
+                                                                                   \
+    auto functor = [&, kernel, element_count](const std::vector<void*>& inputs,    \
+                                              std::vector<void*>& outputs) {       \
+        kernel(inputs[0], inputs[1], outputs[0], element_count, 0);                \
+    };                                                                             \
+    return functor;
+
 #define REGISTER_OP_BUILDER(OP)                                                    \
     static struct __register_##OP##_builder                                        \
     {                                                                              \
@@ -253,6 +279,29 @@
         }                                                                          \
     } __register_##OP##_builder_instance;

+#define BUILDER_CF_DECL(op_name) CFbuild<op_name>(const ngraph::Node* node)
+
+#define REGISTER_CF_BUILDER(OP)                                                    \
+    static struct __register_##OP##_cf_builder                                     \
+    {                                                                              \
+        __register_##OP##_cf_builder()                                             \
+        {                                                                          \
+            GetGlobalCFDispatcherCPU().insert({type_index(typeid(ngraph::op::OP)), \
+                                               &runtime::cpu::Builder::CFbuild<ngraph::op::OP>}); \
+        }                                                                          \
+    } __register_##OP##_cf_builder_instance;
+
+#define REGISTER_CPU_CF_BUILDER(OP)                                                \
+    static struct __register_##OP##_cf_builder                                     \
+    {                                                                              \
+        __register_##OP##_cf_builder()                                             \
+        {                                                                          \
+            GetGlobalCFDispatcherCPU().insert(                                     \
+                {type_index(typeid(ngraph::runtime::cpu::op::OP)),                 \
+                 &runtime::cpu::Builder::CFbuild<ngraph::runtime::cpu::op::OP>});  \
+        }                                                                          \
+    } __register_##OP##_cf_builder_instance;
+
 namespace ngraph
 {
     namespace runtime
@@ -269,6 +318,9 @@ namespace ngraph
             BuildOpMap& GetGlobalBuildDispatcher();

+            // build the map to use cpu kernel for node execution
+            BuildNodeExecutorMap& GetGlobalCFDispatcherCPU();
+
             class Builder
             {
             public:

@@ -282,6 +334,13 @@ namespace ngraph
                                          "' in CPU builder");
                 }

+                template <typename OP>
+                static NodeExecutorTy CFbuild(const ngraph::Node* node)
+                {
+                    throw unsupported_op("Unimplemented op '" + node->description() +
+                                         "' for constant folding in CPU builder");
+                }
+
                 static void nop(CPU_ExternalFunction* external_function,
                                 const ngraph::Node* node,
                                 const std::vector<TensorViewWrapper>& args,
......
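As a reading aid: REGISTER_CF_BUILDER(Abs) statically inserts Builder::CFbuild&lt;ngraph::op::Abs&gt; into the global CF dispatcher, and that specialization's body is just the unary macro expanded. A sketch of roughly what BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::abs) expands to inside CFbuild:

    // Approximate expansion inside Builder::CFbuild<ngraph::op::Abs>(node):
    std::function<void(void*, void*, size_t, int)> kernel;
    SELECT_KERNEL(kernel, node->get_input_element_type(0), runtime::cpu::kernel::abs);
    auto element_count = shape_size(node->get_shape());
    // The returned NodeExecutorTy captures the selected kernel and runs it over
    // the supplied raw buffers; the arena argument is 0 since there is no
    // execution context at folding time.
    auto functor = [&, kernel, element_count](const std::vector<void*>& inputs,
                                              std::vector<void*>& outputs) {
        kernel(inputs[0], outputs[0], element_count, 0);
    };
    return functor;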
@@ -1140,7 +1140,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
     NodeVector nv_cwi; // We dont need CPUWorkspaceInsertion to return list of indices
     REGISTER_KNOBBED_PASS_WITH_ARGS(CPUWorkspaceInsertion, true, runtime::cpu::pass, nv_cwi, false);
     REGISTER_KNOBBED_PASS_WITH_ARGS(CPUAssignment, true, runtime::cpu::pass, this);
-    REGISTER_KNOBBED_PASS(ConstantFolding, false, ngraph::pass);
+    REGISTER_KNOBBED_PASS_WITH_ARGS(
+        ConstantFolding, true, ngraph::pass, GetGlobalCFDispatcherCPU());
     REGISTER_KNOBBED_PASS_WITH_ARGS(CPULayout, true, runtime::cpu::pass, this);
     REGISTER_KNOBBED_PASS_WITH_ARGS(
         CommonSubexpressionElimination, true, ngraph::pass, runtime::cpu::get_cse_handlers_map());
......
@@ -25,6 +25,9 @@
 #include <memory>
 #include <sstream>
 #include <string>
+#include <typeindex>
+#include <typeinfo>
+#include <unordered_map>
 #include <vector>

 #include "ngraph/axis_vector.hpp"
@@ -214,12 +217,21 @@ namespace ngraph
     * This utility takes forward-propogation and back-propagation functions
     * and turns them into clone functions where the intermediate values of
     * the forward prop are added to the output of fprop and the input of the bprop
-    * to avoid repeat calcualtions.
+    * to avoid repeat calculations.
     * The last argument is the adjoints coming into the bprop function, the output
     * bprop function will have these nodes as the first N input parameters
     **/
    FpropCache cache_fprop(std::shared_ptr<Function> fprop, std::shared_ptr<Function> bprop);

+    // NodeExecutors are used in compiler optimization passes like ConstantFolding to execute a node
+    // using the supplied input and output memory locations.
+    // A BuildNodeExecutor returns a backend-specific NodeExecutor for a given Node type
+    using NodeExecutorTy =
+        std::function<void(const std::vector<void*>& inputs, std::vector<void*>& outputs)>;
+    using BuildNodeExecutor = std::function<NodeExecutorTy(const ngraph::Node*)>;
+    using BuildNodeExecutorMap = std::unordered_map<std::type_index, BuildNodeExecutor>;
+
     enum class CPUTensorRole
     {
         INPUT,
......
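To make the NodeExecutorTy contract concrete, a small self-contained sketch (illustrative, not from this commit): an executor is just a callable over raw, pre-allocated buffers, which is what lets the folding pass run kernels without a full backend tensor stack. The doubling kernel here is hypothetical.

    #include <cstddef>
    #include <functional>
    #include <vector>

    using NodeExecutorTy =
        std::function<void(const std::vector<void*>& inputs, std::vector<void*>& outputs)>;

    int main()
    {
        // Hypothetical executor: doubles four floats from inputs[0] into outputs[0].
        NodeExecutorTy exec = [](const std::vector<void*>& inputs, std::vector<void*>& outputs) {
            auto src = static_cast<const float*>(inputs[0]);
            auto dst = static_cast<float*>(outputs[0]);
            for (size_t i = 0; i < 4; ++i)
                dst[i] = 2.0f * src[i];
        };

        std::vector<float> in{1, 2, 3, 4}, out(4);
        std::vector<void*> inputs{in.data()};
        std::vector<void*> outputs{out.data()};
        exec(inputs, outputs); // out == {2, 4, 6, 8}
        return 0;
    }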
@@ -31,9 +31,11 @@
 #include "ngraph/op/erf.hpp"
 #include "ngraph/op/get_output_element.hpp"
 #include "ngraph/op/parameter.hpp"
+#include "ngraph/pass/constant_folding.hpp"
 #include "ngraph/pass/manager.hpp"
 #include "ngraph/pass/visualize_tree.hpp"
 #include "ngraph/runtime/cpu/cpu_backend.hpp"
+#include "ngraph/runtime/cpu/cpu_builder.hpp"
 #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
 #include "ngraph/runtime/cpu/op/conv_bias.hpp"
 #include "ngraph/runtime/cpu/op/convert_layout.hpp"
@@ -949,6 +951,184 @@ TEST(cpu_test, rotated_pooling)
         make_f(false, false), make_f(false, false), "INTERPRETER", "CPU"); // 5D MaxPool
 }

+TEST(cpu_test, constant_reshape)
+{
+    Shape shape_in{2, 4};
+    Shape shape_out{2, 4, 1};
+
+    const vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
+    auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
+    auto reshape = make_shared<op::Reshape>(constant, AxisVector{0, 1}, shape_out);
+    auto f = make_shared<Function>(reshape, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>(
+        ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    const vector<float> values_out = new_const->get_vector<float>();
+
+    EXPECT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
+}
+
+TEST(cpu_test, constant_reshape_permute)
+{
+    Shape shape_in{2, 4};
+    Shape shape_out{4, 2};
+
+    vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
+    auto constant = make_shared<op::Constant>(element::f64, shape_in, values_in);
+    auto reshape = make_shared<op::Reshape>(constant, AxisVector{1, 0}, shape_out);
+    auto f = make_shared<Function>(reshape, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>(
+        ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    const vector<double> values_out = new_const->get_vector<double>();
+
+    const vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
+    EXPECT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
+}
+
+TEST(cpu_test, constant_broadcast)
+{
+    Shape shape_in{2};
+    Shape shape_out{2, 4};
+
+    vector<int> values_in{0, 1};
+    auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
+    auto broadcast = make_shared<op::Broadcast>(constant, shape_out, AxisSet{1});
+    auto f = make_shared<Function>(broadcast, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>(
+        ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<int>();
+
+    vector<int> values_permute{0, 0, 0, 0, 1, 1, 1, 1};
+    ASSERT_EQ(values_permute, values_out);
+}
+
+TEST(cpu_test, constant_pad_exterior)
+{
+    Shape shape_in{2};
+
+    vector<int> values_in{777, 888};
+    auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
+    auto pad_value = make_shared<op::Constant>(element::i32, Shape{}, vector<int>{111});
+
+    CoordinateDiff padding_below{1};
+    CoordinateDiff padding_above{2};
+
+    auto broadcast = make_shared<op::Pad>(constant, pad_value, padding_below, padding_above);
+    auto f = make_shared<Function>(broadcast, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>(
+        ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::Pad>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<int>();
+
+    vector<int> padded_values{111, 777, 888, 111, 111};
+    ASSERT_EQ(padded_values, values_out);
+}
+
+template <typename T>
+static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t pos)
+{
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(pos)->get_argument(0));
+    return new_const->get_vector<T>();
+}
+
+TEST(cpu_test, constant_unary_binary)
+{
+    Shape shape_in{4};
+    vector<int> values_a{1, 2, 3, 4};
+    vector<int> values_b{1, 2, 3, 4};
+    vector<int> values_c{-1, -1, -1, -1};
+    vector<int> values_d{1, 4, 9, 16};
+    vector<int> values_e{1, -2, -3, 4};
+    auto a = make_shared<op::Constant>(element::i32, shape_in, values_a);
+    auto b = make_shared<op::Constant>(element::i32, shape_in, values_b);
+    auto c = make_shared<op::Constant>(element::i32, shape_in, values_c);
+    auto d = make_shared<op::Constant>(element::i32, shape_in, values_d);
+    auto e = make_shared<op::Constant>(element::i32, shape_in, values_e);
+
+    auto add = a + b;
+    auto sub = a - b;
+    auto mul = a * b;
+    auto divn = a / b;
+    auto min = make_shared<op::Minimum>(c, a);
+    auto max = make_shared<op::Maximum>(a, c);
+    auto absn = make_shared<op::Abs>(c);
+    auto neg = make_shared<op::Negative>(c);
+    auto sqrt = make_shared<op::Sqrt>(d);
+    auto neg_sqrt = make_shared<op::Sqrt>(c);
+    auto relu = make_shared<op::Relu>(e);
+
+    auto f = make_shared<Function>(NodeVector{add, sub, mul, divn, min, max, absn, neg, sqrt, relu},
+                                   ParameterVector{});
+    auto f_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>(
+        ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
+    pass_manager.run_passes(f);
+
+    //expected values
+    vector<int> add_expected{2, 4, 6, 8};
+    vector<int> sub_expected{0, 0, 0, 0};
+    vector<int> mul_expected{1, 4, 9, 16};
+    vector<int> div_expected{1, 1, 1, 1};
+    vector<int> min_expected{-1, -1, -1, -1};
+    vector<int> max_expected{1, 2, 3, 4};
+    vector<int> abs_neg_expected{1, 1, 1, 1};
+    vector<int> sqrt_expected{1, 2, 3, 4};
+    vector<int> relu_expected{1, 0, 0, 4};
+
+    ASSERT_EQ(get_result_constant<int>(f, 0), add_expected);
+    ASSERT_EQ(get_result_constant<int>(f, 1), sub_expected);
+    ASSERT_EQ(get_result_constant<int>(f, 2), mul_expected);
+    ASSERT_EQ(get_result_constant<int>(f, 3), div_expected);
+    ASSERT_EQ(get_result_constant<int>(f, 4), min_expected);
+    ASSERT_EQ(get_result_constant<int>(f, 5), max_expected);
+    ASSERT_EQ(get_result_constant<int>(f, 6), abs_neg_expected);
+    ASSERT_EQ(get_result_constant<int>(f, 7), abs_neg_expected);
+    ASSERT_EQ(get_result_constant<int>(f, 8), sqrt_expected);
+    ASSERT_EQ(get_result_constant<int>(f, 9), relu_expected);
+    ASSERT_ANY_THROW(pass_manager.run_passes(f_error));
+}
+
 TEST(cpu_test, conv_test_winograd)
 {
     /* This test checks for the cpu specific graph pass handling for conv_winograd implementation.
......