Commit 978691b4 authored by Diego Caballero's avatar Diego Caballero Committed by nmostafa

[MLIR] Rename LoopKernel->ComputedKernel. Move it to experimental core ops (#12)

We want to use ComputedKernel for any target to delimit sub-graphs to be
compiled and executed with MLIR.
parent d9dd03ce
...@@ -168,6 +168,8 @@ set (SRC ...@@ -168,6 +168,8 @@ set (SRC
op/experimental/quantized_dot.hpp op/experimental/quantized_dot.hpp
op/experimental/quantized_dot_bias.cpp op/experimental/quantized_dot_bias.cpp
op/experimental/quantized_dot_bias.hpp op/experimental/quantized_dot_bias.hpp
op/experimental/compiled_kernel.cpp
op/experimental/compiled_kernel.hpp
op/experimental/transpose.cpp op/experimental/transpose.cpp
op/experimental/transpose.hpp op/experimental/transpose.hpp
op/experimental/layers/ctc_greedy_decoder.cpp op/experimental/layers/ctc_greedy_decoder.cpp
......
...@@ -14,15 +14,16 @@ ...@@ -14,15 +14,16 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "ngraph/runtime/cpu/op/loop_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
shared_ptr<Node> shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector& new_args) const
ngraph::runtime::cpu::op::LoopKernel::copy_with_new_args(const NodeVector& new_args) const
{ {
auto args = get_arguments(); auto args = get_arguments();
if (new_args.size() != args.size()) if (new_args.size() != args.size())
...@@ -56,13 +57,13 @@ shared_ptr<Node> ...@@ -56,13 +57,13 @@ shared_ptr<Node>
new_outputs.push_back(nm.at(o.get())); new_outputs.push_back(nm.at(o.get()));
} }
return std::make_shared<LoopKernel>(new_node_list, new_outputs, new_args); return std::make_shared<CompiledKernel>(new_node_list, new_outputs, new_args);
} }
ngraph::runtime::cpu::op::LoopKernel::LoopKernel(const NodeVector& node_list, ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list,
const NodeVector& outputs, const NodeVector& outputs,
const NodeVector& args) const NodeVector& args)
: Op("LoopKernel", check_single_output_args({args})) : Op("CompiledKernel", check_single_output_args({args}))
, m_node_list(node_list) , m_node_list(node_list)
, m_output_nodes(outputs) , m_output_nodes(outputs)
{ {
......
...@@ -21,18 +21,18 @@ ...@@ -21,18 +21,18 @@
namespace ngraph namespace ngraph
{ {
namespace runtime
{
namespace cpu
{
namespace op namespace op
{ {
/// \brief LoopKernel represents graphs consisting /// \brief CompiledKernel represents a sub-graph that can be compiled and executed
/// of arithmetic operations that can be executed in the same loop /// independently.
class LoopKernel : public ngraph::op::Op ///
/// This op can be used to delimit sub-graphs that with special compilation requirements
/// within a function. For example, we currently use it to delimit sub-graphs that will be
/// independently compiled and executed by MLIR backend.
class CompiledKernel : public ngraph::op::Op
{ {
public: public:
LoopKernel(const NodeVector& node_list, CompiledKernel(const NodeVector& node_list,
const NodeVector& outputs, const NodeVector& outputs,
const NodeVector& args); const NodeVector& args);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -45,6 +45,4 @@ namespace ngraph ...@@ -45,6 +45,4 @@ namespace ngraph
NodeVector m_output_nodes; NodeVector m_output_nodes;
}; };
} }
}
}
} }
...@@ -101,7 +101,6 @@ set(SRC ...@@ -101,7 +101,6 @@ set(SRC
op/group_conv_bias.cpp op/group_conv_bias.cpp
op/halide_op.cpp op/halide_op.cpp
op/leaky_relu.cpp op/leaky_relu.cpp
op/loop_kernel.cpp
op/lstm.cpp op/lstm.cpp
op/matmul_bias.cpp op/matmul_bias.cpp
op/max_pool_with_indices.cpp op/max_pool_with_indices.cpp
...@@ -111,10 +110,10 @@ set(SRC ...@@ -111,10 +110,10 @@ set(SRC
op/update_slice.cpp op/update_slice.cpp
pass/cpu_assignment.cpp pass/cpu_assignment.cpp
pass/cpu_collapse_dims.cpp pass/cpu_collapse_dims.cpp
pass/cpu_compiled_kernel_fusion.cpp
pass/cpu_fusion.cpp pass/cpu_fusion.cpp
pass/cpu_horizontal_fusion.cpp pass/cpu_horizontal_fusion.cpp
pass/cpu_layout.cpp pass/cpu_layout.cpp
pass/cpu_loop_kernel_fusion.cpp
pass/cpu_mat_fusion.cpp pass/cpu_mat_fusion.cpp
pass/cpu_memory_assignment.cpp pass/cpu_memory_assignment.cpp
pass/cpu_memory_optimization.cpp pass/cpu_memory_optimization.cpp
...@@ -137,8 +136,8 @@ endif() ...@@ -137,8 +136,8 @@ endif()
if (NGRAPH_HALIDE) if (NGRAPH_HALIDE)
set(SRC set(SRC
${SRC} ${SRC}
builder/compiled_kernel.cpp
builder/halide_op.cpp builder/halide_op.cpp
builder/loop_kernel.cpp
builder/halide_generators.cpp builder/halide_generators.cpp
pass/halide_subgraph_extraction.cpp pass/halide_subgraph_extraction.cpp
) )
......
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
#include "halide_generators.hpp" #include "halide_generators.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp" #include "ngraph/runtime/cpu/op/compiled_kernel.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -49,10 +49,10 @@ namespace ngraph ...@@ -49,10 +49,10 @@ namespace ngraph
namespace cpu namespace cpu
{ {
template <> template <>
void Builder::BUILDER_DECL(ngraph::runtime::cpu::op::LoopKernel) void Builder::BUILDER_DECL(ngraph::op::CompiledKernel)
{ {
const ngraph::runtime::cpu::op::LoopKernel* hs = const ngraph::op::CompiledKernel* hs =
static_cast<const ngraph::runtime::cpu::op::LoopKernel*>(node); static_cast<const ngraph::op::CompiledKernel*>(node);
const auto& generators = ngraph::runtime::cpu::halide::get_halide_generators(); const auto& generators = ngraph::runtime::cpu::halide::get_halide_generators();
...@@ -99,7 +99,7 @@ namespace ngraph ...@@ -99,7 +99,7 @@ namespace ngraph
//a subgraph //a subgraph
if (op->get_outputs().size() > 1) if (op->get_outputs().size() > 1)
{ {
throw ngraph_error("no multi-output ops in a LoopKernel"); throw ngraph_error("no multi-output ops in a CompiledKernel");
} }
halide_functions[op->get_output_tensor_ptr()->get_name()] = halide_functions[op->get_output_tensor_ptr()->get_name()] =
generators.at(TI(*op))(inputs); generators.at(TI(*op))(inputs);
......
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp" #include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/floor.hpp" #include "ngraph/op/floor.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
...@@ -105,7 +106,6 @@ ...@@ -105,7 +106,6 @@
#include "ngraph/runtime/cpu/mlir/compiler.hpp" #include "ngraph/runtime/cpu/mlir/compiler.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp" #include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp" #include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -444,8 +444,8 @@ namespace ngraph ...@@ -444,8 +444,8 @@ namespace ngraph
{ {
static BuildOpMap build_dispatcher{ static BuildOpMap build_dispatcher{
{TI(ngraph::op::Parameter), &runtime::cpu::Builder::nop}, {TI(ngraph::op::Parameter), &runtime::cpu::Builder::nop},
{TI(ngraph::runtime::cpu::op::LoopKernel), {TI(ngraph::op::CompiledKernel),
&runtime::cpu::Builder::build<ngraph::runtime::cpu::op::LoopKernel>}, &runtime::cpu::Builder::build<ngraph::op::CompiledKernel>},
{TI(ngraph::runtime::cpu::op::HalideOp), {TI(ngraph::runtime::cpu::op::HalideOp),
&runtime::cpu::Builder::build<ngraph::runtime::cpu::op::HalideOp>}}; &runtime::cpu::Builder::build<ngraph::runtime::cpu::op::HalideOp>}};
......
...@@ -117,13 +117,13 @@ ...@@ -117,13 +117,13 @@
#include "ngraph/runtime/cpu/op/batch_mat_mul_transpose.hpp" #include "ngraph/runtime/cpu/op/batch_mat_mul_transpose.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp" #include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/bounded_relu.hpp" #include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/runtime/cpu/op/conv_add.hpp" #include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp" #include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp" #include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/deconv.hpp" #include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp" #include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp" #include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp" #include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp" #include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp" #include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
...@@ -3818,7 +3818,7 @@ namespace ngraph ...@@ -3818,7 +3818,7 @@ namespace ngraph
std::function<std::string(const std::vector<std::string>&)>> std::function<std::string(const std::vector<std::string>&)>>
inline_emitters = initialize_inline_emitters(); inline_emitters = initialize_inline_emitters();
// GOEE doesn't see GOEs in subgraphs that are hidden inside LoopKernels // GOEE doesn't see GOEs in subgraphs that are hidden inside CompiledKernels
// we have to manually propagate the source output // we have to manually propagate the source output
static const ngraph::descriptor::Output* static const ngraph::descriptor::Output*
get_goe_input_output(ngraph::descriptor::Output* output) get_goe_input_output(ngraph::descriptor::Output* output)
...@@ -3833,22 +3833,22 @@ namespace ngraph ...@@ -3833,22 +3833,22 @@ namespace ngraph
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::runtime::cpu::op::LoopKernel) void CPU_Emitter::EMITTER_DECL(ngraph::op::CompiledKernel)
{ {
std::unordered_map<const ngraph::descriptor::Output*, std::string> std::unordered_map<const ngraph::descriptor::Output*, std::string>
loop_symbol_table; loop_symbol_table;
// pre-fill symbol table with inputs // pre-fill symbol table with inputs
const ngraph::runtime::cpu::op::LoopKernel* clk = const ngraph::op::CompiledKernel* ck =
static_cast<const ngraph::runtime::cpu::op::LoopKernel*>(node); static_cast<const ngraph::op::CompiledKernel*>(node);
NodeVector output_nodes = clk->get_kernel_outputs(); NodeVector output_nodes = ck->get_kernel_outputs();
NodeVector node_list = clk->get_node_list(); NodeVector node_list = ck->get_node_list();
for (size_t i = 0; i < args.size(); i++) for (size_t i = 0; i < args.size(); i++)
{ {
std::string sname = std::string(args[i].get_name()) + "[i]"; std::string sname = std::string(args[i].get_name()) + "[i]";
auto entry = std::make_pair(&clk->get_inputs().at(i).get_output(), sname); auto entry = std::make_pair(&ck->get_inputs().at(i).get_output(), sname);
loop_symbol_table.insert(entry); loop_symbol_table.insert(entry);
} }
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/op/abs.hpp" #include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/maximum.hpp" #include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp" #include "ngraph/op/minimum.hpp"
...@@ -31,8 +32,7 @@ ...@@ -31,8 +32,7 @@
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp" #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp" #include "ngraph/runtime/cpu/pass/cpu_compiled_kernel_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_loop_kernel_fusion.hpp"
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
...@@ -49,10 +49,10 @@ struct LKGraph ...@@ -49,10 +49,10 @@ struct LKGraph
NodeVector m_nodes; NodeVector m_nodes;
}; };
class LoopKernelCollector class CompiledKernelCollector
{ {
public: public:
LoopKernelCollector(std::shared_ptr<Function> f, size_t min_nodes_to_fuse) CompiledKernelCollector(std::shared_ptr<Function> f, size_t min_nodes_to_fuse)
{ {
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
{ {
...@@ -70,13 +70,13 @@ public: ...@@ -70,13 +70,13 @@ public:
else else
{ {
auto smallest_head = m_heads.at(arg_from_fusible_group); auto smallest_head = m_heads.at(arg_from_fusible_group);
auto& lkgraph = m_graphs.at(smallest_head); auto& ckgraph = m_graphs.at(smallest_head);
lkgraph.m_nodes.push_back(n); ckgraph.m_nodes.push_back(n);
for (auto arg : n->get_arguments()) for (auto arg : n->get_arguments())
{ {
if (is_leaf(arg)) if (is_leaf(arg))
{ {
lkgraph.m_inputs.push_back(arg); ckgraph.m_inputs.push_back(arg);
} }
} }
m_heads.insert(std::make_pair(n, smallest_head)); m_heads.insert(std::make_pair(n, smallest_head));
...@@ -88,18 +88,18 @@ public: ...@@ -88,18 +88,18 @@ public:
prune_graphs(min_nodes_to_fuse); prune_graphs(min_nodes_to_fuse);
} }
const std::vector<std::shared_ptr<runtime::cpu::op::LoopKernel>> get_loop_kernels() const const std::vector<std::shared_ptr<op::CompiledKernel>> get_compiled_kernels() const
{ {
std::vector<std::shared_ptr<runtime::cpu::op::LoopKernel>> lks; std::vector<std::shared_ptr<op::CompiledKernel>> cks;
for (auto e : m_graphs) for (auto e : m_graphs)
{ {
auto& lkg = e.second; auto& ckg = e.second;
NodeVector member_outputs = ngraph::get_subgraph_outputs(lkg.m_nodes, NodeVector{}); NodeVector member_outputs = ngraph::get_subgraph_outputs(ckg.m_nodes, NodeVector{});
auto lk = std::make_shared<runtime::cpu::op::LoopKernel>( auto ck =
lkg.m_nodes, member_outputs, lkg.m_inputs); std::make_shared<op::CompiledKernel>(ckg.m_nodes, member_outputs, ckg.m_inputs);
lks.push_back(lk); cks.push_back(ck);
} }
return lks; return cks;
} }
private: private:
...@@ -172,20 +172,20 @@ private: ...@@ -172,20 +172,20 @@ private:
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>> m_heads; std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>> m_heads;
}; };
bool ngraph::runtime::cpu::pass::CPULoopKernelFusion::run_on_function( bool ngraph::runtime::cpu::pass::CPUCompiledKernelFusion::run_on_function(
std::shared_ptr<ngraph::Function> function) std::shared_ptr<ngraph::Function> function)
{ {
LoopKernelCollector lkc(function, m_min_kernel_size); CompiledKernelCollector ckc(function, m_min_kernel_size);
auto loop_kernels = lkc.get_loop_kernels(); auto compiled_kernels = ckc.get_compiled_kernels();
for (auto lk : loop_kernels) for (auto ck : compiled_kernels)
{ {
auto outputs = lk->get_kernel_outputs(); auto outputs = ck->get_kernel_outputs();
std::set<std::shared_ptr<Node>> lk_nodes_set(lk->get_node_list().begin(), std::set<std::shared_ptr<Node>> ck_nodes_set(ck->get_node_list().begin(),
lk->get_node_list().end()); ck->get_node_list().end());
for (size_t i = 0; i < outputs.size(); i++) for (size_t i = 0; i < outputs.size(); i++)
{ {
auto ith_goe = std::make_shared<ngraph::op::GetOutputElement>(lk, i); auto ith_goe = std::make_shared<ngraph::op::GetOutputElement>(ck, i);
auto& ith_output = ith_goe->get_outputs().at(0); auto& ith_output = ith_goe->get_outputs().at(0);
if (outputs.at(i)->get_outputs().size() > 1) if (outputs.at(i)->get_outputs().size() > 1)
...@@ -203,8 +203,8 @@ bool ngraph::runtime::cpu::pass::CPULoopKernelFusion::run_on_function( ...@@ -203,8 +203,8 @@ bool ngraph::runtime::cpu::pass::CPULoopKernelFusion::run_on_function(
for (auto input : inputs_copy) for (auto input : inputs_copy)
{ {
// this user is NOT internal to this loop kernel // this user is NOT internal to this loop kernel
// so it needs to be replaced with corresponding lk's GOE // so it needs to be replaced with corresponding ck's GOE
if (lk_nodes_set.count(input->get_node()) == 0) if (ck_nodes_set.count(input->get_node()) == 0)
{ {
input->replace_output(ith_output); input->replace_output(ith_output);
} }
...@@ -212,5 +212,5 @@ bool ngraph::runtime::cpu::pass::CPULoopKernelFusion::run_on_function( ...@@ -212,5 +212,5 @@ bool ngraph::runtime::cpu::pass::CPULoopKernelFusion::run_on_function(
} }
} }
return !loop_kernels.empty(); return !compiled_kernels.empty();
} }
...@@ -26,10 +26,10 @@ namespace ngraph ...@@ -26,10 +26,10 @@ namespace ngraph
{ {
namespace pass namespace pass
{ {
class CPULoopKernelFusion : public ngraph::pass::FunctionPass class CPUCompiledKernelFusion : public ngraph::pass::FunctionPass
{ {
public: public:
CPULoopKernelFusion(size_t min_kernel_size = 2) CPUCompiledKernelFusion(size_t min_kernel_size = 2)
: FunctionPass() : FunctionPass()
, m_min_kernel_size(min_kernel_size) , m_min_kernel_size(min_kernel_size)
{ {
......
...@@ -50,6 +50,7 @@ ...@@ -50,6 +50,7 @@
#include "ngraph/op/erf.hpp" #include "ngraph/op/erf.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp" #include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp" #include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp" #include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
...@@ -63,31 +64,9 @@ ...@@ -63,31 +64,9 @@
#include "ngraph/op/experimental/quantized_dot_bias.hpp" #include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp" #include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp" #include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp" #include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp" #include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp" #include "ngraph/op/greater_eq.hpp"
...@@ -118,8 +97,6 @@ ...@@ -118,8 +97,6 @@
#include "ngraph/op/result.hpp" #include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp" #include "ngraph/op/select.hpp"
#include "ngraph/op/sigmoid.hpp" #include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp" #include "ngraph/op/sign.hpp"
...@@ -143,14 +120,6 @@ using namespace std; ...@@ -143,14 +120,6 @@ using namespace std;
using json = nlohmann::json; using json = nlohmann::json;
using const_data_callback_t = shared_ptr<Node>(const string&, const element::Type&, const Shape&); using const_data_callback_t = shared_ptr<Node>(const string&, const element::Type&, const Shape&);
static bool s_serialize_output_shapes_enabled =
(std::getenv("NGRAPH_SERIALIZER_OUTPUT_SHAPES") != nullptr);
void ngraph::set_serialize_output_shapes(bool enable)
{
s_serialize_output_shapes_enabled = enable;
}
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this: // This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs, // Abs,
// Acos, // Acos,
...@@ -238,7 +207,7 @@ static json write_partial_shape(const PartialShape& s) ...@@ -238,7 +207,7 @@ static json write_partial_shape(const PartialShape& s)
{ {
vals[i] = write_dimension(s[i]); vals[i] = write_dimension(s[i]);
} }
return move(vals); return vals;
} }
} }
...@@ -259,27 +228,6 @@ static PartialShape read_partial_shape(const json& j) ...@@ -259,27 +228,6 @@ static PartialShape read_partial_shape(const json& j)
} }
} }
static json write_auto_broadcast(const op::AutoBroadcastSpec& autob)
{
json j;
j["type"] = autob.m_type;
j["axis"] = autob.m_axis;
return j;
}
static op::AutoBroadcastSpec read_auto_broadcast(const json& j)
{
if (!j.is_object())
{
return op::AutoBroadcastSpec();
}
else
{
return op::AutoBroadcastSpec(static_cast<op::AutoBroadcastType>(j.at("type")),
j.at("axis").get<size_t>());
}
}
static json write_element_type(const ngraph::element::Type& n) static json write_element_type(const ngraph::element::Type& n)
{ {
json j; json j;
...@@ -328,12 +276,6 @@ void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, si ...@@ -328,12 +276,6 @@ void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, si
} }
void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t indent) void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t indent)
{
out << ::serialize(func, indent, false);
}
#if defined ENABLE_CPIO_FILE
static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, size_t indent)
{ {
string j = ::serialize(func, indent, true); string j = ::serialize(func, indent, true);
cpio::Writer writer(out); cpio::Writer writer(out);
...@@ -353,7 +295,6 @@ static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, s ...@@ -353,7 +295,6 @@ static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, s
true); true);
}); });
} }
#endif
static string serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data) static string serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data)
{ {
...@@ -487,18 +428,6 @@ static json write(const Function& f, bool binary_constant_data) ...@@ -487,18 +428,6 @@ static json write(const Function& f, bool binary_constant_data)
return function; return function;
} }
template <typename T>
T get_value(nlohmann::json js, const string& key)
{
T rc;
auto it = js.find(key);
if (it != js.end())
{
rc = it->get<T>();
}
return rc;
}
static shared_ptr<ngraph::Function> static shared_ptr<ngraph::Function>
read_function(const json& func_js, read_function(const json& func_js,
unordered_map<string, shared_ptr<Function>>& function_map, unordered_map<string, shared_ptr<Function>>& function_map,
...@@ -515,23 +444,28 @@ static shared_ptr<ngraph::Function> ...@@ -515,23 +444,28 @@ static shared_ptr<ngraph::Function>
try try
{ {
string node_name = node_js.at("name").get<string>(); string node_name = node_js.at("name").get<string>();
string friendly_name;
auto it = node_js.find("friendly_name");
if (it != node_js.end())
{
friendly_name = it->get<string>();
}
string node_op = node_js.at("op").get<string>(); string node_op = node_js.at("op").get<string>();
string friendly_name = get_value<string>(node_js, "friendly_name"); vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> node_inputs = get_value<vector<string>>(node_js, "inputs"); vector<string> control_deps_inputs =
vector<string> control_deps_inputs = get_value<vector<string>>(node_js, "control_deps"); get_or_default<vector<string>>(node_js, "control_deps", vector<string>{});
vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs"); vector<string> node_outputs = node_js.at("outputs").get<vector<string>>();
shared_ptr<Node> node; shared_ptr<Node> node;
vector<shared_ptr<Node>> args; vector<shared_ptr<Node>> args;
vector<shared_ptr<Node>> control_deps;
for (const string& name : node_inputs) for (const string& name : node_inputs)
{ {
args.push_back(node_map.at(name)); args.push_back(node_map.at(name));
} }
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch" #pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum" #pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wimplicit-fallthrough" // #pragma GCC diagnostic error "-Wimplicit-fallthrough"
#endif
switch (get_typeid(node_op)) switch (get_typeid(node_op))
{ {
case OP_TYPEID::Abs: case OP_TYPEID::Abs:
...@@ -546,8 +480,7 @@ static shared_ptr<ngraph::Function> ...@@ -546,8 +480,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Add: case OP_TYPEID::Add:
{ {
node = node = make_shared<op::Add>(args[0], args[1]);
make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::All: case OP_TYPEID::All:
...@@ -563,8 +496,7 @@ static shared_ptr<ngraph::Function> ...@@ -563,8 +496,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::And: case OP_TYPEID::And:
{ {
node = node = make_shared<op::And>(args[0], args[1]);
make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Any: case OP_TYPEID::Any:
...@@ -606,16 +538,12 @@ static shared_ptr<ngraph::Function> ...@@ -606,16 +538,12 @@ static shared_ptr<ngraph::Function>
auto padding_above = node_js.at("padding_above").get<vector<size_t>>(); auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation = auto include_padding_in_avg_computation =
node_js.at("include_padding_in_avg_computation").get<bool>(); node_js.at("include_padding_in_avg_computation").get<bool>();
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
node = make_shared<op::AvgPool>(args[0], node = make_shared<op::AvgPool>(args[0],
window_shape, window_shape,
window_movement_strides, window_movement_strides,
padding_below, padding_below,
padding_above, padding_above,
include_padding_in_avg_computation, include_padding_in_avg_computation);
pad_type);
break; break;
} }
case OP_TYPEID::AvgPoolBackprop: case OP_TYPEID::AvgPoolBackprop:
...@@ -637,12 +565,6 @@ static shared_ptr<ngraph::Function> ...@@ -637,12 +565,6 @@ static shared_ptr<ngraph::Function>
include_padding_in_avg_computation); include_padding_in_avg_computation);
break; break;
} }
case OP_TYPEID::BatchMatMul:
{
node = make_shared<op::BatchMatMul>(args[0], args[1]);
break;
}
case OP_TYPEID::BatchNormTraining: case OP_TYPEID::BatchNormTraining:
{ {
auto epsilon = node_js.at("eps").get<double>(); auto epsilon = node_js.at("eps").get<double>();
...@@ -689,13 +611,6 @@ static shared_ptr<ngraph::Function> ...@@ -689,13 +611,6 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Ceiling>(args[0]); node = make_shared<op::Ceiling>(args[0]);
break; break;
} }
case OP_TYPEID::Clamp:
{
const auto clamp_min = node_js.at("min").get<float>();
const auto clamp_max = node_js.at("max").get<float>();
node = make_shared<op::Clamp>(args[0], clamp_min, clamp_max);
break;
}
case OP_TYPEID::Concat: case OP_TYPEID::Concat:
{ {
auto axis = node_js.at("axis").get<size_t>(); auto axis = node_js.at("axis").get<size_t>();
...@@ -708,8 +623,16 @@ static shared_ptr<ngraph::Function> ...@@ -708,8 +623,16 @@ static shared_ptr<ngraph::Function>
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js; node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
auto element_type = read_element_type(type_node_js.at("element_type")); auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape"); auto shape = type_node_js.at("shape");
auto value = node_js.at("value").get<vector<string>>(); auto value_it = node_js.find("value");
if (value_it != node_js.end())
{
auto value = value_it->get<vector<string>>();
node = make_shared<op::Constant>(element_type, shape, value); node = make_shared<op::Constant>(element_type, shape, value);
}
else
{
node = const_data_callback(node_name, element_type, shape);
}
break; break;
} }
case OP_TYPEID::Convert: case OP_TYPEID::Convert:
...@@ -735,10 +658,6 @@ static shared_ptr<ngraph::Function> ...@@ -735,10 +658,6 @@ static shared_ptr<ngraph::Function>
data_dilation_strides_maybe = node_js["image_dilation_strides"]; data_dilation_strides_maybe = node_js["image_dilation_strides"];
} }
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
if (data_dilation_strides_maybe.empty()) if (data_dilation_strides_maybe.empty())
{ {
node = make_shared<op::Convolution>(args[0], node = make_shared<op::Convolution>(args[0],
...@@ -757,8 +676,7 @@ static shared_ptr<ngraph::Function> ...@@ -757,8 +676,7 @@ static shared_ptr<ngraph::Function>
window_dilation_strides, window_dilation_strides,
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides_maybe.get<std::vector<size_t>>(), data_dilation_strides_maybe.get<std::vector<size_t>>());
pad_type);
} }
break; break;
} }
...@@ -808,75 +726,6 @@ static shared_ptr<ngraph::Function> ...@@ -808,75 +726,6 @@ static shared_ptr<ngraph::Function>
data_dilation_strides_forward); data_dilation_strides_forward);
break; break;
} }
case OP_TYPEID::ConvolutionBias:
{
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_dilation_strides =
node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides =
node_js.at("data_dilation_strides").get<vector<size_t>>();
node = make_shared<op::ConvolutionBias>(args[0],
args[1],
args[2],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides);
break;
}
case OP_TYPEID::ConvolutionBiasAdd:
{
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_dilation_strides =
node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides =
node_js.at("data_dilation_strides").get<vector<size_t>>();
node = make_shared<op::ConvolutionBiasAdd>(args[0],
args[1],
args[2],
args[3],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides);
break;
}
case OP_TYPEID::ConvolutionBiasBackpropFiltersBias:
{
auto filters_shape = node_js.at("filters_shape").get<vector<size_t>>();
auto bias_shape = node_js.at("bias_shape").get<vector<size_t>>();
auto window_movement_strides_forward =
node_js.at("window_movement_strides_forward").get<vector<size_t>>();
auto window_dilation_strides_forward =
node_js.at("window_dilation_strides_forward").get<vector<size_t>>();
auto padding_below_forward =
node_js.at("padding_below_forward").get<vector<std::ptrdiff_t>>();
auto padding_above_forward =
node_js.at("padding_above_forward").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides_forward =
node_js.at("data_dilation_strides_forward").get<vector<size_t>>();
node = make_shared<op::ConvolutionBiasBackpropFiltersBias>(
args[0],
filters_shape,
bias_shape,
args[1],
window_movement_strides_forward,
window_dilation_strides_forward,
padding_below_forward,
padding_above_forward,
data_dilation_strides_forward);
break;
}
case OP_TYPEID::Cos: case OP_TYPEID::Cos:
{ {
node = make_shared<op::Cos>(args[0]); node = make_shared<op::Cos>(args[0]);
...@@ -887,12 +736,6 @@ static shared_ptr<ngraph::Function> ...@@ -887,12 +736,6 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Cosh>(args[0]); node = make_shared<op::Cosh>(args[0]);
break; break;
} }
case OP_TYPEID::DepthToSpace:
{
auto block_size = node_js.at("block_size").get<size_t>();
node = make_shared<op::DepthToSpace>(args[0], block_size);
break;
}
case OP_TYPEID::Dequantize: case OP_TYPEID::Dequantize:
{ {
auto type = read_element_type(node_js.at("type")); auto type = read_element_type(node_js.at("type"));
...@@ -902,8 +745,7 @@ static shared_ptr<ngraph::Function> ...@@ -902,8 +745,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Divide: case OP_TYPEID::Divide:
{ {
node = make_shared<op::Divide>( node = make_shared<op::Divide>(args[0], args[1]);
args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Dot: case OP_TYPEID::Dot:
...@@ -941,11 +783,6 @@ static shared_ptr<ngraph::Function> ...@@ -941,11 +783,6 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::DynSlice>(args[0], args[1], args[2], args[3]); node = make_shared<op::DynSlice>(args[0], args[1], args[2], args[3]);
break; break;
} }
case OP_TYPEID::Elu:
{
node = make_shared<op::Elu>(args[0], args[1]);
break;
}
case OP_TYPEID::EmbeddingLookup: case OP_TYPEID::EmbeddingLookup:
{ {
node = make_shared<op::EmbeddingLookup>(args[0], args[1]); node = make_shared<op::EmbeddingLookup>(args[0], args[1]);
...@@ -953,8 +790,7 @@ static shared_ptr<ngraph::Function> ...@@ -953,8 +790,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Equal: case OP_TYPEID::Equal:
{ {
node = node = make_shared<op::Equal>(args[0], args[1]);
make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Erf: case OP_TYPEID::Erf:
...@@ -967,39 +803,11 @@ static shared_ptr<ngraph::Function> ...@@ -967,39 +803,11 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Exp>(args[0]); node = make_shared<op::Exp>(args[0]);
break; break;
} }
case OP_TYPEID::FakeQuantize:
{
size_t levels = node_js.at("levels").get<size_t>();
node = make_shared<op::FakeQuantize>(
args[0], args[1], args[2], args[3], args[4], levels);
break;
}
case OP_TYPEID::Floor: case OP_TYPEID::Floor:
{ {
node = make_shared<op::Floor>(args[0]); node = make_shared<op::Floor>(args[0]);
break; break;
} }
case OP_TYPEID::Gather:
{
auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::Gather>(args[0], args[1], axis);
break;
}
case OP_TYPEID::GatherND:
{
node = make_shared<op::GatherND>(args[0], args[1]);
break;
}
case OP_TYPEID::Gemm:
{
auto alpha = node_js.at("alpha").get<double>();
auto beta = node_js.at("beta").get<double>();
auto transA = node_js.at("transA").get<bool>();
auto transB = node_js.at("transB").get<bool>();
node =
make_shared<op::Gemm>(args[0], args[1], args[2], alpha, beta, transA, transB);
break;
}
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
auto output_shape = node_js.at("output_shape").get<vector<size_t>>(); auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
...@@ -1018,71 +826,22 @@ static shared_ptr<ngraph::Function> ...@@ -1018,71 +826,22 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Greater: case OP_TYPEID::Greater:
{ {
node = make_shared<op::Greater>( node = make_shared<op::Greater>(args[0], args[1]);
args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::GreaterEq: case OP_TYPEID::GreaterEq:
{ {
node = make_shared<op::GreaterEq>( node = make_shared<op::GreaterEq>(args[0], args[1]);
args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::GRN:
{
auto bias = node_js.at("bias").get<float>();
node = make_shared<op::GRN>(args[0], bias);
break;
}
case OP_TYPEID::HardSigmoid:
{
auto alpha = node_js.at("alpha").get<float>();
auto beta = node_js.at("beta").get<float>();
node = make_shared<op::HardSigmoid>(args[0], alpha, beta);
break;
}
case OP_TYPEID::GroupConvolution:
{
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_dilation_strides =
node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides =
node_js.at("data_dilation_strides").get<vector<size_t>>();
auto groups = node_js.at("groups").get<size_t>();
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
node = make_shared<op::GroupConvolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
groups,
pad_type);
break;
}
case OP_TYPEID::LeakyRelu:
{
node = make_shared<op::LeakyRelu>(args[0], args[1]);
break; break;
} }
case OP_TYPEID::Less: case OP_TYPEID::Less:
{ {
node = node = make_shared<op::Less>(args[0], args[1]);
make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::LessEq: case OP_TYPEID::LessEq:
{ {
node = make_shared<op::LessEq>( node = make_shared<op::LessEq>(args[0], args[1]);
args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Log: case OP_TYPEID::Log:
...@@ -1114,9 +873,6 @@ static shared_ptr<ngraph::Function> ...@@ -1114,9 +873,6 @@ static shared_ptr<ngraph::Function>
// omitted. // omitted.
auto padding_below_maybe = node_js["padding_below"]; auto padding_below_maybe = node_js["padding_below"];
auto padding_above_maybe = node_js["padding_above"]; auto padding_above_maybe = node_js["padding_above"];
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
if (padding_below_maybe.empty() && !padding_above_maybe.empty()) if (padding_below_maybe.empty() && !padding_above_maybe.empty())
{ {
throw runtime_error( throw runtime_error(
...@@ -1135,8 +891,7 @@ static shared_ptr<ngraph::Function> ...@@ -1135,8 +891,7 @@ static shared_ptr<ngraph::Function>
window_shape, window_shape,
window_movement_strides, window_movement_strides,
padding_below, padding_below,
padding_above, padding_above);
pad_type);
} }
else else
{ {
...@@ -1174,8 +929,7 @@ static shared_ptr<ngraph::Function> ...@@ -1174,8 +929,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Maximum: case OP_TYPEID::Maximum:
{ {
node = make_shared<op::Maximum>( node = make_shared<op::Maximum>(args[0], args[1]);
args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Min: case OP_TYPEID::Min:
...@@ -1186,22 +940,12 @@ static shared_ptr<ngraph::Function> ...@@ -1186,22 +940,12 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Minimum: case OP_TYPEID::Minimum:
{ {
node = make_shared<op::Minimum>( node = make_shared<op::Minimum>(args[0], args[1]);
args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Multiply: case OP_TYPEID::Multiply:
{ {
node = make_shared<op::Multiply>( node = make_shared<op::Multiply>(args[0], args[1]);
args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::MVN:
{
auto normalize_variance = node_js.at("normalize_variance").get<bool>();
auto across_channels = node_js.at("across_channels").get<bool>();
auto eps = node_js.at("eps").get<double>();
node = make_shared<op::MVN>(args[0], normalize_variance, across_channels, eps);
break; break;
} }
case OP_TYPEID::Negative: case OP_TYPEID::Negative:
...@@ -1209,19 +953,9 @@ static shared_ptr<ngraph::Function> ...@@ -1209,19 +953,9 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Negative>(args[0]); node = make_shared<op::Negative>(args[0]);
break; break;
} }
case OP_TYPEID::Normalize:
{
bool across_spatial = node_js.at("across_spatial").get<bool>();
bool channel_shared = node_js.at("channel_shared").get<bool>();
float eps = node_js.at("eps").get<float>();
node = make_shared<op::Normalize>(
args[0], args[1], across_spatial, channel_shared, eps);
break;
}
case OP_TYPEID::NotEqual: case OP_TYPEID::NotEqual:
{ {
node = make_shared<op::NotEqual>( node = make_shared<op::NotEqual>(args[0], args[1]);
args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Not: case OP_TYPEID::Not:
...@@ -1238,7 +972,7 @@ static shared_ptr<ngraph::Function> ...@@ -1238,7 +972,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Or: case OP_TYPEID::Or:
{ {
node = make_shared<op::Or>(args[0], args[1], read_auto_broadcast(node_js["autob"])); node = make_shared<op::Or>(args[0], args[1]);
break; break;
} }
case OP_TYPEID::Pad: case OP_TYPEID::Pad:
...@@ -1249,11 +983,15 @@ static shared_ptr<ngraph::Function> ...@@ -1249,11 +983,15 @@ static shared_ptr<ngraph::Function>
// This is a legacy field whose functionality is no longer supported. The new // This is a legacy field whose functionality is no longer supported. The new
// behavior is equivalent to interior padding of 0, so we will accept it under // behavior is equivalent to interior padding of 0, so we will accept it under
// those conditions. // those conditions.
auto padding_interior = get_value<vector<size_t>>(node_js, "padding_interior"); auto padding_interior_maybe = node_js.find("padding_interior");
NGRAPH_CHECK(std::all_of(padding_interior.begin(), if (padding_interior_maybe != node_js.end())
{
auto padding_interior = padding_interior_maybe->get<vector<size_t>>();
NGRAPH_ASSERT(std::all_of(padding_interior.begin(),
padding_interior.end(), padding_interior.end(),
[](size_t s) { return s == 0; }), [](size_t s) { return s == 0; }))
"Legacy padding_interior field must be zero everywhere."); << "Legacy padding_interior field must be zero everywhere.";
}
auto pad_mode = node_js.count("pad_mode") == 0 auto pad_mode = node_js.count("pad_mode") == 0
? op::PadMode::CONSTANT ? op::PadMode::CONSTANT
...@@ -1292,8 +1030,7 @@ static shared_ptr<ngraph::Function> ...@@ -1292,8 +1030,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Power: case OP_TYPEID::Power:
{ {
node = node = make_shared<op::Power>(args[0], args[1]);
make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::PRelu: case OP_TYPEID::PRelu:
...@@ -1430,21 +1167,6 @@ static shared_ptr<ngraph::Function> ...@@ -1430,21 +1167,6 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::ScalarConstantLike>(args[0], value); node = make_shared<op::ScalarConstantLike>(args[0], value);
break; break;
} }
case OP_TYPEID::ScaleShift:
{
node = make_shared<op::ScaleShift>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::ScatterAdd:
{
node = make_shared<op::ScatterAdd>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::ScatterNDAdd:
{
node = make_shared<op::ScatterNDAdd>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::Select: case OP_TYPEID::Select:
{ {
node = make_shared<op::Select>(args[0], args[1], args[2]); node = make_shared<op::Select>(args[0], args[1], args[2]);
...@@ -1455,13 +1177,6 @@ static shared_ptr<ngraph::Function> ...@@ -1455,13 +1177,6 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::ShapeOf>(args[0]); node = make_shared<op::ShapeOf>(args[0]);
break; break;
} }
case OP_TYPEID::ShuffleChannels:
{
const auto axis = node_js.at("axis").get<size_t>();
const auto groups = node_js.at("groups").get<size_t>();
node = make_shared<op::ShuffleChannels>(args[0], axis, groups);
break;
}
case OP_TYPEID::Sigmoid: case OP_TYPEID::Sigmoid:
{ {
node = make_shared<op::Sigmoid>(args[0]); node = make_shared<op::Sigmoid>(args[0]);
...@@ -1501,38 +1216,14 @@ static shared_ptr<ngraph::Function> ...@@ -1501,38 +1216,14 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Softmax>(args[0], softmax_axes); node = make_shared<op::Softmax>(args[0], softmax_axes);
break; break;
} }
case OP_TYPEID::SpaceToDepth:
{
auto block_size = node_js.at("block_size").get<size_t>();
node = make_shared<op::SpaceToDepth>(args[0], block_size);
break;
}
case OP_TYPEID::Split:
{
const auto axis = node_js.at("axis").get<size_t>();
const auto splits = node_js.at("splits").get<vector<size_t>>();
node = make_shared<op::Split>(args[0], axis, splits);
break;
}
case OP_TYPEID::Sqrt: case OP_TYPEID::Sqrt:
{ {
node = make_shared<op::Sqrt>(args[0]); node = make_shared<op::Sqrt>(args[0]);
break; break;
} }
case OP_TYPEID::SquaredDifference:
{
node = make_shared<op::SquaredDifference>(args[0], args[1]);
break;
}
case OP_TYPEID::Squeeze:
{
node = make_shared<op::Squeeze>(args[0], args[1]);
break;
}
case OP_TYPEID::Subtract: case OP_TYPEID::Subtract:
{ {
node = make_shared<op::Subtract>( node = make_shared<op::Subtract>(args[0], args[1]);
args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Sum: case OP_TYPEID::Sum:
...@@ -1551,11 +1242,6 @@ static shared_ptr<ngraph::Function> ...@@ -1551,11 +1242,6 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Tanh>(args[0]); node = make_shared<op::Tanh>(args[0]);
break; break;
} }
case OP_TYPEID::Tile:
{
node = make_shared<op::Tile>(args[0], args[1]);
break;
}
case OP_TYPEID::TopK: case OP_TYPEID::TopK:
{ {
auto top_k_axis = node_js.at("top_k_axis").get<size_t>(); auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
...@@ -1575,11 +1261,6 @@ static shared_ptr<ngraph::Function> ...@@ -1575,11 +1261,6 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::StopGradient>(args[0]); node = make_shared<op::StopGradient>(args[0]);
break; break;
} }
case OP_TYPEID::Unsqueeze:
{
node = make_shared<op::Unsqueeze>(args[0], args[1]);
break;
}
case OP_TYPEID::UnknownOp: case OP_TYPEID::UnknownOp:
{ {
stringstream ss; stringstream ss;
...@@ -1587,9 +1268,7 @@ static shared_ptr<ngraph::Function> ...@@ -1587,9 +1268,7 @@ static shared_ptr<ngraph::Function>
throw runtime_error(ss.str()); throw runtime_error(ss.str());
} }
} }
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#endif
for (const string& name : control_deps_inputs) for (const string& name : control_deps_inputs)
{ {
...@@ -1669,33 +1348,24 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1669,33 +1348,24 @@ static json write(const Node& n, bool binary_constant_data)
json control_deps = json::array(); json control_deps = json::array();
json outputs = json::array(); json outputs = json::array();
for (auto& input : n.inputs()) for (const descriptor::Input& input : n.get_inputs())
{ {
inputs.push_back(input.get_source_output().get_node()->get_name()); inputs.push_back(input.get_output().get_node()->get_name());
} }
for (auto cdep : n.get_control_dependencies()) for (auto cdep : n.get_control_dependencies())
{ {
control_deps.push_back(cdep->get_name()); control_deps.push_back(cdep->get_name());
} }
for (auto& output : n.outputs()) for (size_t i = 0; i < n.get_output_size(); ++i)
{ {
outputs.push_back(output.get_tensor().get_name()); outputs.push_back(n.get_output_tensor(i).get_name());
} }
if (!inputs.empty())
{
node["inputs"] = inputs; node["inputs"] = inputs;
}
if (!control_deps.empty())
{
node["control_deps"] = control_deps; node["control_deps"] = control_deps;
}
if (!outputs.empty())
{
node["outputs"] = outputs; node["outputs"] = outputs;
}
if (s_serialize_output_shapes_enabled) if (std::getenv("NGRAPH_SERIALIZER_OUTPUT_SHAPES") != nullptr)
{ {
json output_shapes = json::array(); json output_shapes = json::array();
for (size_t i = 0; i < n.get_output_size(); ++i) for (size_t i = 0; i < n.get_output_size(); ++i)
...@@ -1706,26 +1376,17 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1706,26 +1376,17 @@ static json write(const Node& n, bool binary_constant_data)
} }
string node_op = n.description(); string node_op = n.description();
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch" #pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum" #pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wimplicit-fallthrough" // #pragma GCC diagnostic error "-Wimplicit-fallthrough"
#endif
switch (get_typeid(node_op)) switch (get_typeid(node_op))
{ {
case OP_TYPEID::Abs: { break; case OP_TYPEID::Abs: { break;
} }
case OP_TYPEID::Acos: { break; case OP_TYPEID::Acos: { break;
} }
case OP_TYPEID::Add: case OP_TYPEID::Add: { break;
{
auto tmp = dynamic_cast<const op::Add*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::ArgMin: case OP_TYPEID::ArgMin:
{ {
...@@ -1749,14 +1410,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1749,14 +1410,7 @@ static json write(const Node& n, bool binary_constant_data)
} }
case OP_TYPEID::AllReduce: { break; case OP_TYPEID::AllReduce: { break;
} }
case OP_TYPEID::And: case OP_TYPEID::And: { break;
{
auto tmp = dynamic_cast<const op::And*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::Any: case OP_TYPEID::Any:
{ {
...@@ -1776,7 +1430,6 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1776,7 +1430,6 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation(); node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
node["pad_type"] = tmp->get_pad_type();
break; break;
} }
case OP_TYPEID::AvgPoolBackprop: case OP_TYPEID::AvgPoolBackprop:
...@@ -1790,8 +1443,6 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1790,8 +1443,6 @@ static json write(const Node& n, bool binary_constant_data)
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation(); node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
break; break;
} }
case OP_TYPEID::BatchMatMul: { break;
}
case OP_TYPEID::BatchNormTraining: case OP_TYPEID::BatchNormTraining:
{ {
auto tmp = dynamic_cast<const op::BatchNormTraining*>(&n); auto tmp = dynamic_cast<const op::BatchNormTraining*>(&n);
...@@ -1827,13 +1478,6 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1827,13 +1478,6 @@ static json write(const Node& n, bool binary_constant_data)
} }
case OP_TYPEID::Ceiling: { break; case OP_TYPEID::Ceiling: { break;
} }
case OP_TYPEID::Clamp:
{
auto tmp = dynamic_cast<const op::Clamp*>(&n);
node["min"] = tmp->get_min();
node["max"] = tmp->get_max();
break;
}
case OP_TYPEID::Concat: case OP_TYPEID::Concat:
{ {
auto tmp = dynamic_cast<const op::Concat*>(&n); auto tmp = dynamic_cast<const op::Concat*>(&n);
...@@ -1843,13 +1487,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1843,13 +1487,7 @@ static json write(const Node& n, bool binary_constant_data)
case OP_TYPEID::Constant: case OP_TYPEID::Constant:
{ {
auto tmp = dynamic_cast<const op::Constant*>(&n); auto tmp = dynamic_cast<const op::Constant*>(&n);
if (tmp->are_all_data_elements_bitwise_identical()) if (!binary_constant_data)
{
vector<string> vs;
vs.push_back(tmp->convert_value_to_string(0));
node["value"] = vs;
}
else
{ {
node["value"] = tmp->get_value_strings(); node["value"] = tmp->get_value_strings();
} }
...@@ -1871,7 +1509,6 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1871,7 +1509,6 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides(); node["data_dilation_strides"] = tmp->get_data_dilation_strides();
node["pad_type"] = tmp->get_pad_type();
break; break;
} }
case OP_TYPEID::ConvolutionBackpropData: case OP_TYPEID::ConvolutionBackpropData:
...@@ -1896,38 +1533,6 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1896,38 +1533,6 @@ static json write(const Node& n, bool binary_constant_data)
node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward(); node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward();
break; break;
} }
case OP_TYPEID::ConvolutionBias:
{
auto tmp = dynamic_cast<const op::ConvolutionBias*>(&n);
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["window_dilation_strides"] = tmp->get_window_dilation_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides();
break;
}
case OP_TYPEID::ConvolutionBiasAdd:
{
auto tmp = dynamic_cast<const op::ConvolutionBiasAdd*>(&n);
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["window_dilation_strides"] = tmp->get_window_dilation_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides();
break;
}
case OP_TYPEID::ConvolutionBiasBackpropFiltersBias:
{
auto tmp = dynamic_cast<const op::ConvolutionBiasBackpropFiltersBias*>(&n);
node["filters_shape"] = tmp->get_filters_shape();
node["bias_shape"] = tmp->get_bias_shape();
node["window_movement_strides_forward"] = tmp->get_window_movement_strides_forward();
node["window_dilation_strides_forward"] = tmp->get_window_dilation_strides_forward();
node["padding_below_forward"] = tmp->get_padding_below_forward();
node["padding_above_forward"] = tmp->get_padding_above_forward();
node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward();
break;
}
case OP_TYPEID::Cos: { break; case OP_TYPEID::Cos: { break;
} }
case OP_TYPEID::Cosh: { break; case OP_TYPEID::Cosh: { break;
...@@ -1939,21 +1544,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1939,21 +1544,7 @@ static json write(const Node& n, bool binary_constant_data)
node["axes"] = tmp->get_axes(); node["axes"] = tmp->get_axes();
break; break;
} }
case OP_TYPEID::DepthToSpace: case OP_TYPEID::Divide: { break;
{
auto tmp = dynamic_cast<const op::DepthToSpace*>(&n);
node["type"] = write_element_type(tmp->get_element_type());
node["block_size"] = tmp->get_block_size();
break;
}
case OP_TYPEID::Divide:
{
auto tmp = dynamic_cast<const op::Divide*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::Dot: case OP_TYPEID::Dot:
{ {
...@@ -1969,54 +1560,22 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1969,54 +1560,22 @@ static json write(const Node& n, bool binary_constant_data)
} }
case OP_TYPEID::DynSlice: { break; case OP_TYPEID::DynSlice: { break;
} }
case OP_TYPEID::Elu: { break;
}
case OP_TYPEID::EmbeddingLookup: { break; case OP_TYPEID::EmbeddingLookup: { break;
} }
case OP_TYPEID::Equal: case OP_TYPEID::Equal: { break;
{
auto tmp = dynamic_cast<const op::Equal*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::Erf: { break; case OP_TYPEID::Erf: { break;
} }
case OP_TYPEID::Exp: { break; case OP_TYPEID::Exp: { break;
} }
case OP_TYPEID::FakeQuantize:
{
auto tmp = dynamic_cast<const op::FakeQuantize*>(&n);
node["levels"] = tmp->get_levels();
break;
}
case OP_TYPEID::Floor: { break; case OP_TYPEID::Floor: { break;
} }
case OP_TYPEID::Gather:
{
auto tmp = dynamic_cast<const op::Gather*>(&n);
node["axis"] = tmp->get_axis();
break;
}
case OP_TYPEID::GatherND: { break;
}
case OP_TYPEID::GetOutputElement: case OP_TYPEID::GetOutputElement:
{ {
auto tmp = dynamic_cast<const op::GetOutputElement*>(&n); auto tmp = dynamic_cast<const op::GetOutputElement*>(&n);
node["n"] = tmp->get_n(); node["n"] = tmp->get_n();
break; break;
} }
case OP_TYPEID::Gemm:
{
auto tmp = dynamic_cast<const op::Gemm*>(&n);
node["alpha"] = tmp->get_alpha();
node["beta"] = tmp->get_beta();
node["transA"] = tmp->get_transA();
node["transB"] = tmp->get_transB();
break;
}
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
auto tmp = dynamic_cast<const op::GenerateMask*>(&n); auto tmp = dynamic_cast<const op::GenerateMask*>(&n);
...@@ -2026,68 +1585,13 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2026,68 +1585,13 @@ static json write(const Node& n, bool binary_constant_data)
node["probability"] = tmp->get_probability(); node["probability"] = tmp->get_probability();
break; break;
} }
case OP_TYPEID::Greater: case OP_TYPEID::Greater: { break;
{
auto tmp = dynamic_cast<const op::Greater*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
} }
break; case OP_TYPEID::GreaterEq: { break;
} }
case OP_TYPEID::GreaterEq: case OP_TYPEID::Less: { break;
{
auto tmp = dynamic_cast<const op::GreaterEq*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::GRN:
{
auto tmp = dynamic_cast<const op::GRN*>(&n);
node["bias"] = tmp->get_bias();
break;
} }
case OP_TYPEID::HardSigmoid: case OP_TYPEID::LessEq: { break;
{
auto tmp = dynamic_cast<const op::HardSigmoid*>(&n);
node["alpha"] = tmp->get_alpha();
node["beta"] = tmp->get_beta();
break;
}
case OP_TYPEID::GroupConvolution:
{
auto tmp = dynamic_cast<const op::GroupConvolution*>(&n);
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["window_dilation_strides"] = tmp->get_window_dilation_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides();
node["groups"] = tmp->get_groups();
node["pad_type"] = tmp->get_pad_type();
break;
}
case OP_TYPEID::LeakyRelu: { break;
}
case OP_TYPEID::Less:
{
auto tmp = dynamic_cast<const op::Less*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::LessEq:
{
auto tmp = dynamic_cast<const op::LessEq*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::Log: { break; case OP_TYPEID::Log: { break;
} }
...@@ -2113,7 +1617,6 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2113,7 +1617,6 @@ static json write(const Node& n, bool binary_constant_data)
node["window_movement_strides"] = tmp->get_window_movement_strides(); node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["pad_type"] = tmp->get_pad_type();
break; break;
} }
case OP_TYPEID::MaxPoolBackprop: case OP_TYPEID::MaxPoolBackprop:
...@@ -2125,14 +1628,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2125,14 +1628,7 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
break; break;
} }
case OP_TYPEID::Maximum: case OP_TYPEID::Maximum: { break;
{
auto tmp = dynamic_cast<const op::Maximum*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::Min: case OP_TYPEID::Min:
{ {
...@@ -2140,50 +1636,13 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2140,50 +1636,13 @@ static json write(const Node& n, bool binary_constant_data)
node["reduction_axes"] = tmp->get_reduction_axes(); node["reduction_axes"] = tmp->get_reduction_axes();
break; break;
} }
case OP_TYPEID::Minimum: case OP_TYPEID::Minimum: { break;
{
auto tmp = dynamic_cast<const op::Minimum*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::Multiply:
{
auto tmp = dynamic_cast<const op::Multiply*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::MVN: case OP_TYPEID::Multiply: { break;
{
auto tmp = dynamic_cast<const op::MVN*>(&n);
node["normalize_variance"] = tmp->get_normalize_variance();
node["across_channels"] = tmp->get_across_channels();
node["eps"] = tmp->get_eps();
break;
} }
case OP_TYPEID::Negative: { break; case OP_TYPEID::Negative: { break;
} }
case OP_TYPEID::Normalize: case OP_TYPEID::NotEqual: { break;
{
auto tmp = dynamic_cast<const op::Normalize*>(&n);
node["across_spatial"] = tmp->get_across_spatial();
node["channel_shared"] = tmp->get_channel_shared();
node["eps"] = tmp->get_eps();
break;
}
case OP_TYPEID::NotEqual:
{
auto tmp = dynamic_cast<const op::NotEqual*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::Not: { break; case OP_TYPEID::Not: { break;
} }
...@@ -2194,14 +1653,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2194,14 +1653,7 @@ static json write(const Node& n, bool binary_constant_data)
node["one_hot_axis"] = tmp->get_one_hot_axis(); node["one_hot_axis"] = tmp->get_one_hot_axis();
break; break;
} }
case OP_TYPEID::Or: case OP_TYPEID::Or: { break;
{
auto tmp = dynamic_cast<const op::Or*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::Pad: case OP_TYPEID::Pad:
{ {
...@@ -2244,14 +1696,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2244,14 +1696,7 @@ static json write(const Node& n, bool binary_constant_data)
node["reduction_axes"] = tmp->get_reduction_axes(); node["reduction_axes"] = tmp->get_reduction_axes();
break; break;
} }
case OP_TYPEID::Power: case OP_TYPEID::Power: { break;
{
auto tmp = dynamic_cast<const op::Power*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::Quantize: case OP_TYPEID::Quantize:
{ {
...@@ -2344,23 +1789,10 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2344,23 +1789,10 @@ static json write(const Node& n, bool binary_constant_data)
node["element_type"] = write_element_type(constant->get_element_type()); node["element_type"] = write_element_type(constant->get_element_type());
break; break;
} }
case OP_TYPEID::ScaleShift: { break;
}
case OP_TYPEID::ScatterAdd: { break;
}
case OP_TYPEID::ScatterNDAdd: { break;
}
case OP_TYPEID::Select: { break; case OP_TYPEID::Select: { break;
} }
case OP_TYPEID::ShapeOf: { break; case OP_TYPEID::ShapeOf: { break;
} }
case OP_TYPEID::ShuffleChannels:
{
const auto tmp = dynamic_cast<const op::ShuffleChannels*>(&n);
node["axis"] = tmp->get_axis();
node["groups"] = tmp->get_groups();
break;
}
case OP_TYPEID::Sigmoid: { break; case OP_TYPEID::Sigmoid: { break;
} }
case OP_TYPEID::SigmoidBackprop: { break; case OP_TYPEID::SigmoidBackprop: { break;
...@@ -2379,36 +1811,11 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2379,36 +1811,11 @@ static json write(const Node& n, bool binary_constant_data)
node["strides"] = tmp->get_strides(); node["strides"] = tmp->get_strides();
break; break;
} }
case OP_TYPEID::SpaceToDepth:
{
auto tmp = dynamic_cast<const op::SpaceToDepth*>(&n);
node["type"] = write_element_type(tmp->get_element_type());
node["block_size"] = tmp->get_block_size();
break;
}
case OP_TYPEID::Split:
{
auto tmp = dynamic_cast<const op::Split*>(&n);
node["axis"] = tmp->get_axis();
node["splits"] = tmp->get_splits();
break;
}
case OP_TYPEID::Sqrt: { break; case OP_TYPEID::Sqrt: { break;
} }
case OP_TYPEID::SquaredDifference: { break;
}
case OP_TYPEID::Squeeze: { break;
}
case OP_TYPEID::StopGradient: { break; case OP_TYPEID::StopGradient: { break;
} }
case OP_TYPEID::Subtract: case OP_TYPEID::Subtract: { break;
{
auto tmp = dynamic_cast<const op::Subtract*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::Sum: case OP_TYPEID::Sum:
{ {
...@@ -2426,8 +1833,6 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2426,8 +1833,6 @@ static json write(const Node& n, bool binary_constant_data)
} }
case OP_TYPEID::Tanh: { break; case OP_TYPEID::Tanh: { break;
} }
case OP_TYPEID::Tile: { break;
}
case OP_TYPEID::TopK: case OP_TYPEID::TopK:
{ {
auto tmp = dynamic_cast<const op::TopK*>(&n); auto tmp = dynamic_cast<const op::TopK*>(&n);
...@@ -2439,14 +1844,10 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2439,14 +1844,10 @@ static json write(const Node& n, bool binary_constant_data)
} }
case OP_TYPEID::Transpose: { break; case OP_TYPEID::Transpose: { break;
} }
case OP_TYPEID::Unsqueeze: { break;
}
case OP_TYPEID::UnknownOp: { break; case OP_TYPEID::UnknownOp: { break;
} }
} }
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#endif
return node; return node;
} }
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/dequantize.hpp" #include "ngraph/op/dequantize.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp" #include "ngraph/op/experimental/quantized_concat.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp" #include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp" #include "ngraph/op/experimental/quantized_conv_bias.hpp"
...@@ -66,7 +67,6 @@ ...@@ -66,7 +67,6 @@
#include "ngraph/runtime/cpu/op/deconv.hpp" #include "ngraph/runtime/cpu/op/deconv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp" #include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp" #include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp" #include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp" #include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp" #include "ngraph/runtime/cpu/op/rnn.hpp"
...@@ -1413,15 +1413,15 @@ TEST(cpu_fusion, backwards_maxpool_with_indices_n4_c1_hw4_2x2_max) ...@@ -1413,15 +1413,15 @@ TEST(cpu_fusion, backwards_maxpool_with_indices_n4_c1_hw4_2x2_max)
#if defined(NGRAPH_HALIDE) #if defined(NGRAPH_HALIDE)
TEST(cpu_fusion, loop_kernel_one_input_one_output_halide) TEST(cpu_fusion, compiled_kernel_one_input_one_output_halide)
{ {
Shape shapeA{2, 2}; Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA); auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto relu_a = make_shared<op::Relu>(A); auto relu_a = make_shared<op::Relu>(A);
auto relu_relu_a = make_shared<op::Relu>(relu_a); auto relu_relu_a = make_shared<op::Relu>(relu_a);
auto lk = make_shared<runtime::cpu::op::LoopKernel>( auto ck = make_shared<op::CompiledKernel>(
NodeVector{relu_a, relu_relu_a}, NodeVector{relu_relu_a}, NodeVector{A}); NodeVector{relu_a, relu_relu_a}, NodeVector{relu_relu_a}, NodeVector{A});
auto f = make_shared<Function>(NodeVector{lk}, ParameterVector{A}); auto f = make_shared<Function>(NodeVector{ck}, ParameterVector{A});
auto backend = runtime::Backend::create("CPU"); auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA); shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA);
...@@ -1437,7 +1437,7 @@ TEST(cpu_fusion, loop_kernel_one_input_one_output_halide) ...@@ -1437,7 +1437,7 @@ TEST(cpu_fusion, loop_kernel_one_input_one_output_halide)
EXPECT_TRUE(test::all_close(read_vector<float>(result), expected)); EXPECT_TRUE(test::all_close(read_vector<float>(result), expected));
} }
TEST(cpu_fusion, loop_kernel_two_input_two_output_halide) TEST(cpu_fusion, compiled_kernel_two_input_two_output_halide)
{ {
Shape shapeA{2, 2}; Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA); auto A = make_shared<op::Parameter>(element::f32, shapeA);
...@@ -1445,11 +1445,11 @@ TEST(cpu_fusion, loop_kernel_two_input_two_output_halide) ...@@ -1445,11 +1445,11 @@ TEST(cpu_fusion, loop_kernel_two_input_two_output_halide)
auto relu_a = make_shared<op::Relu>(A); auto relu_a = make_shared<op::Relu>(A);
auto add_ab = make_shared<op::Add>(relu_a, B); auto add_ab = make_shared<op::Add>(relu_a, B);
auto lk = make_shared<runtime::cpu::op::LoopKernel>( auto ck = make_shared<op::CompiledKernel>(
NodeVector{relu_a, add_ab}, NodeVector{relu_a, add_ab}, NodeVector{A, B}); NodeVector{relu_a, add_ab}, NodeVector{relu_a, add_ab}, NodeVector{A, B});
auto goe1 = make_shared<op::GetOutputElement>(lk, 0); auto goe1 = make_shared<op::GetOutputElement>(ck, 0);
auto goe2 = make_shared<op::GetOutputElement>(lk, 1); auto goe2 = make_shared<op::GetOutputElement>(ck, 1);
auto f = make_shared<Function>(NodeVector{goe1, goe2}, ParameterVector{A, B}); auto f = make_shared<Function>(NodeVector{goe1, goe2}, ParameterVector{A, B});
auto backend = runtime::Backend::create("CPU"); auto backend = runtime::Backend::create("CPU");
...@@ -1471,7 +1471,7 @@ TEST(cpu_fusion, loop_kernel_two_input_two_output_halide) ...@@ -1471,7 +1471,7 @@ TEST(cpu_fusion, loop_kernel_two_input_two_output_halide)
EXPECT_TRUE(test::all_close(read_vector<float>(result_relu), expected_relu)); EXPECT_TRUE(test::all_close(read_vector<float>(result_relu), expected_relu));
} }
TEST(cpu_fusion, loop_kernel_embedded_graph_halide) TEST(cpu_fusion, compiled_kernel_embedded_graph_halide)
{ {
Shape shapeA{2, 2}; Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA); auto A = make_shared<op::Parameter>(element::f32, shapeA);
...@@ -1479,9 +1479,9 @@ TEST(cpu_fusion, loop_kernel_embedded_graph_halide) ...@@ -1479,9 +1479,9 @@ TEST(cpu_fusion, loop_kernel_embedded_graph_halide)
auto neg_a = make_shared<op::Negative>(A); auto neg_a = make_shared<op::Negative>(A);
auto neg_b = make_shared<op::Negative>(B); auto neg_b = make_shared<op::Negative>(B);
auto add = neg_a + neg_b; auto add = neg_a + neg_b;
auto lk = make_shared<runtime::cpu::op::LoopKernel>( auto ck =
NodeVector{add}, NodeVector{add}, NodeVector{neg_a, neg_b}); make_shared<op::CompiledKernel>(NodeVector{add}, NodeVector{add}, NodeVector{neg_a, neg_b});
auto f = make_shared<Function>(NodeVector{lk}, ParameterVector{A, B}); auto f = make_shared<Function>(NodeVector{ck}, ParameterVector{A, B});
auto backend = runtime::Backend::create("CPU"); auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA); shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA);
...@@ -1498,15 +1498,14 @@ TEST(cpu_fusion, loop_kernel_embedded_graph_halide) ...@@ -1498,15 +1498,14 @@ TEST(cpu_fusion, loop_kernel_embedded_graph_halide)
EXPECT_TRUE(test::all_close_f(read_vector<float>(result), expected, MIN_FLOAT_TOLERANCE_BITS)); EXPECT_TRUE(test::all_close_f(read_vector<float>(result), expected, MIN_FLOAT_TOLERANCE_BITS));
} }
TEST(cpu_fusion, loop_kernel_two_inputs_one_output_halide) TEST(cpu_fusion, compiled_kernel_two_inputs_one_output_halide)
{ {
Shape shapeA{2, 2}; Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA); auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeA); auto B = make_shared<op::Parameter>(element::f32, shapeA);
auto add = A + B; auto add = A + B;
auto lk = make_shared<runtime::cpu::op::LoopKernel>( auto ck = make_shared<op::CompiledKernel>(NodeVector{add}, NodeVector{add}, NodeVector{A, B});
NodeVector{add}, NodeVector{add}, NodeVector{A, B}); auto f = make_shared<Function>(NodeVector{ck}, ParameterVector{A, B});
auto f = make_shared<Function>(NodeVector{lk}, ParameterVector{A, B});
auto backend = runtime::Backend::create("CPU"); auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA); shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shapeA);
...@@ -1525,7 +1524,7 @@ TEST(cpu_fusion, loop_kernel_two_inputs_one_output_halide) ...@@ -1525,7 +1524,7 @@ TEST(cpu_fusion, loop_kernel_two_inputs_one_output_halide)
EXPECT_TRUE(test::all_close_f(read_vector<float>(result), expected, MIN_FLOAT_TOLERANCE_BITS)); EXPECT_TRUE(test::all_close_f(read_vector<float>(result), expected, MIN_FLOAT_TOLERANCE_BITS));
} }
TEST(cpu_fusion, loop_kernel_multiple_outputs_halide) TEST(cpu_fusion, compiled_kernel_multiple_outputs_halide)
{ {
Shape shapeA{2, 2}; Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA); auto A = make_shared<op::Parameter>(element::f32, shapeA);
...@@ -1542,13 +1541,13 @@ TEST(cpu_fusion, loop_kernel_multiple_outputs_halide) ...@@ -1542,13 +1541,13 @@ TEST(cpu_fusion, loop_kernel_multiple_outputs_halide)
auto add_aab = add_ab_abs + A; auto add_aab = add_ab_abs + A;
auto add_cdd = add_cd_abs + D; auto add_cdd = add_cd_abs + D;
auto lk = make_shared<runtime::cpu::op::LoopKernel>( auto ck = make_shared<op::CompiledKernel>(
NodeVector{neg_a, neg_b, add_ab, add_cd, add_cd_abs, add_ab_abs, add_aab, add_cdd}, NodeVector{neg_a, neg_b, add_ab, add_cd, add_cd_abs, add_ab_abs, add_aab, add_cdd},
NodeVector{add_aab, add_cdd, neg_b}, NodeVector{add_aab, add_cdd, neg_b},
NodeVector{A, B, C, D}); NodeVector{A, B, C, D});
auto add_aab_goe = std::make_shared<op::GetOutputElement>(lk, 0); auto add_aab_goe = std::make_shared<op::GetOutputElement>(ck, 0);
auto add_cdd_goe = std::make_shared<op::GetOutputElement>(lk, 1); auto add_cdd_goe = std::make_shared<op::GetOutputElement>(ck, 1);
auto neg_b_goe = std::make_shared<op::GetOutputElement>(lk, 2); auto neg_b_goe = std::make_shared<op::GetOutputElement>(ck, 2);
auto f = make_shared<Function>(NodeVector{add_aab_goe, add_cdd_goe, neg_b_goe}, auto f = make_shared<Function>(NodeVector{add_aab_goe, add_cdd_goe, neg_b_goe},
ParameterVector{A, B, C, D}); ParameterVector{A, B, C, D});
...@@ -1583,7 +1582,7 @@ TEST(cpu_fusion, loop_kernel_multiple_outputs_halide) ...@@ -1583,7 +1582,7 @@ TEST(cpu_fusion, loop_kernel_multiple_outputs_halide)
EXPECT_TRUE(test::all_close_f(read_vector<float>(r3), expected3, MIN_FLOAT_TOLERANCE_BITS)); EXPECT_TRUE(test::all_close_f(read_vector<float>(r3), expected3, MIN_FLOAT_TOLERANCE_BITS));
} }
TEST(cpu_fusion, loop_kernel_copy_with_new_args) TEST(cpu_fusion, compiled_kernel_copy_with_new_args)
{ {
Shape shapeA{2, 2}; Shape shapeA{2, 2};
auto A = make_shared<op::Parameter>(element::i32, shapeA); auto A = make_shared<op::Parameter>(element::i32, shapeA);
...@@ -1600,13 +1599,13 @@ TEST(cpu_fusion, loop_kernel_copy_with_new_args) ...@@ -1600,13 +1599,13 @@ TEST(cpu_fusion, loop_kernel_copy_with_new_args)
auto add_aab = add_ab_abs + A; auto add_aab = add_ab_abs + A;
auto add_cdd = add_cd_abs + D; auto add_cdd = add_cd_abs + D;
auto lk = make_shared<runtime::cpu::op::LoopKernel>( auto ck = make_shared<op::CompiledKernel>(
NodeVector{neg_a, neg_b, add_ab, add_cd, add_cd_abs, add_ab_abs, add_aab, add_cdd}, NodeVector{neg_a, neg_b, add_ab, add_cd, add_cd_abs, add_ab_abs, add_aab, add_cdd},
NodeVector{add_aab, add_cdd, neg_b}, NodeVector{add_aab, add_cdd, neg_b},
NodeVector{A, B, C, D}); NodeVector{A, B, C, D});
auto add_aab_goe = std::make_shared<op::GetOutputElement>(lk, 0); auto add_aab_goe = std::make_shared<op::GetOutputElement>(ck, 0);
auto add_cdd_goe = std::make_shared<op::GetOutputElement>(lk, 1); auto add_cdd_goe = std::make_shared<op::GetOutputElement>(ck, 1);
auto neg_b_goe = std::make_shared<op::GetOutputElement>(lk, 2); auto neg_b_goe = std::make_shared<op::GetOutputElement>(ck, 2);
auto f = make_shared<Function>(NodeVector{add_aab_goe, add_cdd_goe, neg_b_goe}, auto f = make_shared<Function>(NodeVector{add_aab_goe, add_cdd_goe, neg_b_goe},
ParameterVector{A, B, C, D}); ParameterVector{A, B, C, D});
...@@ -2167,7 +2166,7 @@ TEST(cpu_fusion, rnn_fprop_1_lstm_cell) ...@@ -2167,7 +2166,7 @@ TEST(cpu_fusion, rnn_fprop_1_lstm_cell)
#if 0 #if 0
TEST(cpu_fusion, loop_kernel_fusion_multiple_groups_pruned) TEST(cpu_fusion, compiled_kernel_fusion_multiple_groups_pruned)
{ {
auto make_function = []() -> std::shared_ptr<Function> { auto make_function = []() -> std::shared_ptr<Function> {
Shape shape{}; Shape shape{};
...@@ -2192,15 +2191,15 @@ TEST(cpu_fusion, loop_kernel_fusion_multiple_groups_pruned) ...@@ -2192,15 +2191,15 @@ TEST(cpu_fusion, loop_kernel_fusion_multiple_groups_pruned)
}; };
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPULoopKernelFusion>(3); pass_manager.register_pass<runtime::cpu::pass::CPUCompiledKernelFusion>(3);
auto cpu_f = make_function(); auto cpu_f = make_function();
auto int_f = make_function(); auto int_f = make_function();
pass_manager.run_passes(cpu_f); pass_manager.run_passes(cpu_f);
test::Uniform<float> rng(-100.0f, 100.0f); test::Uniform<float> rng(-100.0f, 100.0f);
vector<vector<float>> args; vector<vector<float>> args;
size_t lkn = count_ops_of_type<runtime::cpu::op::LoopKernel>(cpu_f); size_t ckn = count_ops_of_type<op::CompiledKernel>(cpu_f);
ASSERT_GT(lkn, 0); ASSERT_GT(ckn, 0);
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters()) for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{ {
...@@ -2216,7 +2215,7 @@ TEST(cpu_fusion, loop_kernel_fusion_multiple_groups_pruned) ...@@ -2216,7 +2215,7 @@ TEST(cpu_fusion, loop_kernel_fusion_multiple_groups_pruned)
} }
} }
TEST(cpu_fusion, loop_kernel_fusion_bounded_relu) TEST(cpu_fusion, compiled_kernel_fusion_bounded_relu)
{ {
auto make_function = []() -> std::shared_ptr<Function> { auto make_function = []() -> std::shared_ptr<Function> {
Shape shape{}; Shape shape{};
...@@ -2235,7 +2234,7 @@ TEST(cpu_fusion, loop_kernel_fusion_bounded_relu) ...@@ -2235,7 +2234,7 @@ TEST(cpu_fusion, loop_kernel_fusion_bounded_relu)
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before_relu_fusion.png"); pass_manager.register_pass<pass::VisualizeTree>("before_relu_fusion.png");
pass_manager.register_pass<runtime::cpu::pass::CPULoopKernelFusion>(3); pass_manager.register_pass<runtime::cpu::pass::CPUCompiledKernelFusion>(3);
pass_manager.register_pass<pass::VisualizeTree>("after_relu_fusion.png"); pass_manager.register_pass<pass::VisualizeTree>("after_relu_fusion.png");
auto cpu_f = make_function(); auto cpu_f = make_function();
auto int_f = make_function(); auto int_f = make_function();
...@@ -2243,8 +2242,8 @@ TEST(cpu_fusion, loop_kernel_fusion_bounded_relu) ...@@ -2243,8 +2242,8 @@ TEST(cpu_fusion, loop_kernel_fusion_bounded_relu)
test::Uniform<float> rng(-100.0f, 100.0f); test::Uniform<float> rng(-100.0f, 100.0f);
vector<vector<float>> args; vector<vector<float>> args;
size_t lkn = count_ops_of_type<runtime::cpu::op::LoopKernel>(cpu_f); size_t ckn = count_ops_of_type<op::CompiledKernel>(cpu_f);
ASSERT_GT(lkn, 0); ASSERT_GT(ckn, 0);
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters()) for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{ {
...@@ -2260,7 +2259,7 @@ TEST(cpu_fusion, loop_kernel_fusion_bounded_relu) ...@@ -2260,7 +2259,7 @@ TEST(cpu_fusion, loop_kernel_fusion_bounded_relu)
} }
} }
TEST(cpu_fusion, loop_kernel_fusion_multiple_groups) TEST(cpu_fusion, compiled_kernel_fusion_multiple_groups)
{ {
auto make_function = []() -> std::shared_ptr<Function> { auto make_function = []() -> std::shared_ptr<Function> {
Shape shape{}; Shape shape{};
...@@ -2285,15 +2284,15 @@ TEST(cpu_fusion, loop_kernel_fusion_multiple_groups) ...@@ -2285,15 +2284,15 @@ TEST(cpu_fusion, loop_kernel_fusion_multiple_groups)
}; };
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPULoopKernelFusion>(2); pass_manager.register_pass<runtime::cpu::pass::CPUCompiledKernelFusion>(2);
auto cpu_f = make_function(); auto cpu_f = make_function();
auto int_f = make_function(); auto int_f = make_function();
pass_manager.run_passes(cpu_f); pass_manager.run_passes(cpu_f);
test::Uniform<float> rng(-100.0f, 100.0f); test::Uniform<float> rng(-100.0f, 100.0f);
vector<vector<float>> args; vector<vector<float>> args;
size_t lkn = count_ops_of_type<runtime::cpu::op::LoopKernel>(cpu_f); size_t ckn = count_ops_of_type<op::CompiledKernel>(cpu_f);
ASSERT_GT(lkn, 0); ASSERT_GT(ckn, 0);
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters()) for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{ {
...@@ -2309,7 +2308,7 @@ TEST(cpu_fusion, loop_kernel_fusion_multiple_groups) ...@@ -2309,7 +2308,7 @@ TEST(cpu_fusion, loop_kernel_fusion_multiple_groups)
} }
} }
TEST(cpu_fusion, loop_kernel_fusion_one_group) TEST(cpu_fusion, compiled_kernel_fusion_one_group)
{ {
auto make_function = []() -> std::shared_ptr<Function> { auto make_function = []() -> std::shared_ptr<Function> {
Shape shape{}; Shape shape{};
...@@ -2335,15 +2334,15 @@ TEST(cpu_fusion, loop_kernel_fusion_one_group) ...@@ -2335,15 +2334,15 @@ TEST(cpu_fusion, loop_kernel_fusion_one_group)
}; };
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPULoopKernelFusion>(2); pass_manager.register_pass<runtime::cpu::pass::CPUCompiledKernelFusion>(2);
auto cpu_f = make_function(); auto cpu_f = make_function();
auto int_f = make_function(); auto int_f = make_function();
pass_manager.run_passes(cpu_f); pass_manager.run_passes(cpu_f);
test::Uniform<float> rng(-100.0f, 100.0f); test::Uniform<float> rng(-100.0f, 100.0f);
vector<vector<float>> args; vector<vector<float>> args;
size_t lkn = count_ops_of_type<runtime::cpu::op::LoopKernel>(cpu_f); size_t ckn = count_ops_of_type<op::CompiledKernel>(cpu_f);
ASSERT_GT(lkn, 0); ASSERT_GT(ckn, 0);
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters()) for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment