Commit 2e88d948 authored by Sandeep, committed by Robert Kimball

hybrid at core (#1821)

* skeleton backend

* Add code owner; gate hybrid build behind if conditioning

* add simple placement for interpreter and register pass in hybrid

* placement policy applied

* clone the function if needed

* split the function

* Compile subfunctions in corresponding backends

* hybrid backend works as-is for abc test

* cleanup

* add placement policy for CPU

* cleanup a little

* add simple op cost method to backend

* enable CPU pass via flag

* address clang-format PR issue

* resolve build

* clean-up

* update manifest

* disable HYBRID as default build

* style

* addressing offline discussion

* more offline discussion
parent cf15ef32
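Taken together, the bullets describe the hybrid flow: assign each node a placement, clone the function, split it into per-placement subfunctions, and compile each subfunction on its placement's backend. A minimal sketch of that flow, assuming ngraph's pass::AssignPlacement plus split_function_by_placement and get_colocated_function_placement helpers (the helper names, placement_policy callable, and exact signatures are assumptions, not part of this commit):

// Hedged sketch of the compile flow described by the commit message; not the
// committed code. placement_policy, split_function_by_placement, and
// get_colocated_function_placement are assumed names.
std::shared_ptr<Function> prepared = clone_function(*function); // leave the caller's graph intact
ngraph::pass::Manager pm;
pm.register_pass<ngraph::pass::AssignPlacement>(placement_policy); // per-node placement decision
pm.run_passes(prepared);

// Split into subfunctions, each wholly on one placement; the map records which
// upstream Result feeds each downstream Parameter across the cut edges.
std::unordered_map<std::shared_ptr<op::Parameter>, std::shared_ptr<op::Result>> param_to_result;
auto sub_functions = split_function_by_placement(prepared, param_to_result);
for (const auto& sub : sub_functions)
{
    Placement p = get_colocated_function_placement(sub);
    get_cached_backend(p)->compile(sub); // each piece compiles on its own backend
}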
@@ -92,6 +92,7 @@ option(NGRAPH_CPU_ENABLE "Control the building of the CPU backend" TRUE)
option(NGRAPH_INTELGPU_ENABLE "Control the building of the Intel GPU backend with clDNN" FALSE)
option(NGRAPH_GPU_ENABLE "Control the building of the GPU backend" FALSE)
option(NGRAPH_INTERPRETER_ENABLE "Control the building of the INTERPRETER backend" TRUE)
option(NGRAPH_HYBRID_ENABLE "Control the building of the HYBRID backend" FALSE)
option(NGRAPH_DISTRIBUTED_ENABLE "Add distributed mode to the CPU backend" FALSE)
option(NGRAPH_DEBUG_ENABLE "Enable output for NGRAPH_DEBUG statements" FALSE)
option(NGRAPH_ONNX_IMPORT_ENABLE "Enable ONNX importer" FALSE)
......
@@ -40,6 +40,7 @@
/src/ngraph/runtime/ @rkimballn1 @Krovatkin
/src/ngraph/runtime/cpu/ @jbobba
/src/ngraph/runtime/gpu/ @rkimballn1
/src/ngraph/runtime/hybrid/ @sasadep
/src/ngraph/runtime/intelgpu/ @shssf
/src/ngraph/runtime/interpreter/ @rkimballn1
/src/ngraph/runtime/reference/ @aprocter
......
@@ -16,6 +16,11 @@
add_subdirectory(interpreter)
if (NGRAPH_HYBRID_ENABLE)
add_subdirectory(hybrid)
endif()
if (NGRAPH_CPU_ENABLE)
add_subdirectory(cpu)
endif()
......
@@ -107,3 +107,10 @@ void runtime::Backend::validate_call(shared_ptr<const Function> function,
}
}
}
bool runtime::Backend::is_supported(const Node& node) const
{
    // The default behavior is that a backend supports no ops. If this is not the case
    // then override this method and enhance.
    return false;
}
@@ -118,6 +118,11 @@ public:
virtual std::vector<PerformanceCounter>
get_performance_data(std::shared_ptr<Function> func) const;
/// \brief Test if a backend is capable of supporting an op
/// \param node is the op to test.
/// \returns true if the op is supported, false otherwise.
virtual bool is_supported(const Node& node) const;
protected:
void validate_call(std::shared_ptr<const Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
......
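To make the new virtual concrete, here is a hedged example of a derived backend overriding it; MyBackend and its whitelist are purely illustrative, and the Node::description() type-name check is an assumption about this era of the API:

// Illustrative only: a hypothetical backend advertising support for a few ops.
#include <set>
#include <string>
#include "ngraph/runtime/backend.hpp"

class MyBackend : public ngraph::runtime::Backend
{
public:
    bool is_supported(const ngraph::Node& node) const override
    {
        // description() is assumed to return the op's type name, e.g. "Add"
        static const std::set<std::string> whitelist{"Parameter", "Result", "Add", "Multiply"};
        return whitelist.count(node.description()) != 0;
    }
    // create_tensor/compile/call overrides elided
};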
# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
if (NGRAPH_HYBRID_ENABLE)
add_library(hybrid_backend SHARED hybrid_backend.cpp)
set_target_properties(hybrid_backend PROPERTIES VERSION ${NGRAPH_VERSION})
target_link_libraries(hybrid_backend PUBLIC ngraph)
set_target_properties(hybrid_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR})
install(TARGETS hybrid_backend
LIBRARY DESTINATION "${NGRAPH_INSTALL_LIB}"
ARCHIVE DESTINATION "${NGRAPH_INSTALL_LIB}"
)
endif()
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <sstream>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/assign_placement.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/hybrid/hybrid_backend.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
using descriptor::layout::DenseTensorLayout;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::hybrid::HYBRIDBackend();
}
extern "C" void delete_backend(runtime::Backend* backend)
{
delete backend;
}
template <typename T>
void copy_data(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>& data)
{
size_t data_size = data.size() * sizeof(T);
tv->write(data.data(), 0, data_size);
}
template <typename T>
std::vector<T> read_vector(std::shared_ptr<ngraph::runtime::Tensor> tv)
{
if (ngraph::element::from<T>() != tv->get_tensor_layout()->get_element_type())
{
throw std::invalid_argument("read_vector type must match Tensor type");
}
size_t element_count = ngraph::shape_size(tv->get_shape());
size_t size = element_count * sizeof(T);
std::vector<T> rc(element_count);
tv->read(rc.data(), 0, size);
return rc;
}
shared_ptr<runtime::Backend> runtime::hybrid::HYBRIDBackend::get_cached_backend(Placement placement)
{
    // Lazily create one backend per placement and memoize it for reuse
    if (m_cached_backends.find(placement) == m_cached_backends.end())
    {
        m_cached_backends[placement] = runtime::Backend::create(placement_to_string(placement));
    }
    return m_cached_backends.at(placement);
}
shared_ptr<runtime::Tensor> runtime::hybrid::HYBRIDBackend::create_tensor(const element::Type& type,
const Shape& shape)
{
return make_shared<runtime::HostTensor>(type, shape, "external");
}
shared_ptr<runtime::Tensor> runtime::hybrid::HYBRIDBackend::create_tensor(const element::Type& type,
const Shape& shape,
void* memory_pointer)
{
return make_shared<runtime::HostTensor>(type, shape, memory_pointer, "external");
}
bool runtime::hybrid::HYBRIDBackend::compile(shared_ptr<Function> function)
{
    if (m_function_map.find(function) == m_function_map.end())
    {
        // Clone the function so passes can mutate the copy, not the caller's graph
        FunctionInstance instance;
        instance.m_function = clone_function(*function);
        pass::Manager pass_manager;
        // Placement/splitting passes get registered here as the backend fills out
        pass_manager.run_passes(instance.m_function);
        // Store the instance so the clone is not discarded and recompilation is skipped
        m_function_map.insert({function, instance});
    }
    return true;
}
bool runtime::hybrid::HYBRIDBackend::call(shared_ptr<Function> function,
const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
validate_call(function, outputs, inputs);
compile(function);
return true;
}
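In this commit call() only validates and compiles; the eventual dispatch, per the "Compile subfunctions in corresponding backends" bullet, would walk m_sub_functions in order and route tensors between pieces via m_map_parameter_to_result. A hedged sketch of that future shape (get_colocated_function_placement and the tensor-wiring step are assumptions):

// Hedged sketch, not the committed code: what call() grows into once splitting lands.
bool runtime::hybrid::HYBRIDBackend::call(shared_ptr<Function> function,
                                          const vector<shared_ptr<runtime::Tensor>>& outputs,
                                          const vector<shared_ptr<runtime::Tensor>>& inputs)
{
    validate_call(function, outputs, inputs);
    compile(function);
    FunctionInstance& instance = m_function_map.at(function);
    for (const shared_ptr<Function>& sub : instance.m_sub_functions)
    {
        Placement p = get_colocated_function_placement(sub); // assumed helper
        auto backend = get_cached_backend(p);
        // Bind sub's Parameters/Results to caller tensors or to intermediates
        // created with backend->create_tensor(...), consulting
        // instance.m_map_parameter_to_result for the cross-subfunction edges.
        vector<shared_ptr<runtime::Tensor>> sub_inputs;  // wiring elided
        vector<shared_ptr<runtime::Tensor>> sub_outputs; // wiring elided
        backend->call(sub, sub_outputs, sub_inputs);
    }
    return true;
}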
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/tensor.hpp"
namespace ngraph
{
namespace runtime
{
namespace hybrid
{
class HYBRIDBackend : public runtime::Backend
{
public:
std::shared_ptr<Tensor> create_tensor(const element::Type& type,
const Shape& shape,
void* memory_pointer) override;
std::shared_ptr<Tensor> create_tensor(const element::Type& type,
const Shape& shape) override;
bool compile(std::shared_ptr<Function> function) override;
bool call(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& inputs) override;
private:
class FunctionInstance
{
public:
std::shared_ptr<Function> m_function;
std::vector<std::shared_ptr<Function>> m_sub_functions;
std::unordered_map<std::shared_ptr<op::Parameter>, std::shared_ptr<op::Result>>
m_map_parameter_to_result;
};
std::shared_ptr<runtime::Backend> get_cached_backend(Placement placement);
std::map<Placement, std::shared_ptr<runtime::Backend>> m_cached_backends;
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
};
}
}
}
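For orientation, a hedged end-to-end usage sketch of the Backend interface these files implement, reusing the copy_data/read_vector helpers from the test file above; creating the backend by the name "HYBRID" assumes the usual name-to-shared-library loading convention applies to libhybrid_backend:

// Hedged usage sketch: run a small a + b graph through the hybrid backend.
#include "ngraph/ngraph.hpp"
using namespace ngraph;

int main()
{
    Shape shape{2, 2};
    auto A = std::make_shared<op::Parameter>(element::f32, shape);
    auto B = std::make_shared<op::Parameter>(element::f32, shape);
    auto f = std::make_shared<Function>(A + B, op::ParameterVector{A, B});

    // "HYBRID" as a creation string is an assumption based on the
    // <name>_backend library convention used by the other backends.
    auto backend = runtime::Backend::create("HYBRID");
    auto a = backend->create_tensor(element::f32, shape);
    auto b = backend->create_tensor(element::f32, shape);
    auto result = backend->create_tensor(element::f32, shape);

    copy_data(a, std::vector<float>{1, 2, 3, 4});
    copy_data(b, std::vector<float>{5, 6, 7, 8});
    backend->call(f, {result}, {a, b});
    auto out = read_vector<float>(result); // {6, 8, 10, 12} once execution is wired up
    return 0;
}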
abc
abc_int64
abs
acos
add
add_overload
aliased_output
argmax_3D_axis_0
argmax_3D_axis_1
argmax_3D_axis_2
argmax_4D_axis_3
argmax_trivial
argmin_4D_axis_3
argmin_trivial
asin
atan
avg_pool_1d_1channel_1image
avg_pool_1d_1channel_2image
avg_pool_1d_2channel_2image
avg_pool_2d_1channel_1image_padded_do_not_include_in_computation
avg_pool_2d_1channel_1image_padded_include_in_computation
avg_pool_2d_1channel_1image_strided
avg_pool_2d_2channel_2image
avg_pool_2d_2channel_2image_3x3_padded_do_not_include_in_computation
avg_pool_2d_2channel_2image_3x3_padded_include_in_computation
avg_pool_2d_2channel_2image_3x3_strided_padded_do_not_include_in_computation
avg_pool_2d_2channel_2image_3x3_strided_padded_include_in_computation
avg_pool_2d_2channel_2image_3x3_strided_uneven_padded_do_not_include_in_computation
avg_pool_2d_2channel_2image_3x3_strided_uneven_padded_include_in_computation
avg_pool_2d_2channel_2image_padded_do_not_include_in_computation
avg_pool_2d_2channel_2image_padded_include_in_computation
avg_pool_2d_2channel_2image_padded_only_above_do_not_include_in_computation
avg_pool_2d_2channel_2image_padded_only_above_include_in_computation
avg_pool_2d_2channel_2image_padded_only_below_do_not_include_in_computation
avg_pool_2d_2channel_2image_padded_only_below_include_in_computation
avg_pool_3d_strided_uneven_padded_do_not_include_in_computation
avg_pool_3d_uneven_strided_padded_include_in_computation
backwards_abc
backwards_abs
backwards_acos
backwards_add
backwards_add_nested
backwards_asin
backwards_atan
backwards_avgpool_n1_c1_hw2x2
backwards_avgpool_n1_c1_hw4x4
backwards_avgpool_n2_c2_hw2x2_win_2x2_str_1x1_padding_numeric
backwards_avgpool_n2_c2_hw4x4
backwards_avgpool_n2_c2_hw4x4_numeric
backwards_avgpool_n2_c2_hw4x4_win_2x2_str_1x1_numeric
backwards_batch_norm_three_outputs
backwards_broadcast0
backwards_broadcast1
backwards_ceiling
backwards_concat_axis_0
backwards_concat_axis_1
backwards_concat_vector
backwards_cos
backwards_cosh
backwards_divide
backwards_dot_scalar_scalar
backwards_dot_scalar_tensor
backwards_dot_tensor_scalar
backwards_dot_tensor_vector
backwards_dot_tensor2_tensor2
backwards_dot_tensor3_tensor3
backwards_dot_vector_vector
backwards_exp
backwards_floor
backwards_log
backwards_maximum
backwards_maxpool_n2_c1_hw5_3x3_str2_max
backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3
backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2
backwards_maxpool_n4_c1_hw4_2x2_max
backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1
backwards_minimum
backwards_multiply
backwards_negative
backwards_parameter
backwards_power
backwards_relu
backwards_replace_slice
backwards_reshape
backwards_reverse_3d_02
backwards_reverse_sequence_n3_c2_h3
backwards_reverse_sequence_n4d2c3h2w2
backwards_select
backwards_select_nested
backwards_sigmoid
backwards_sign
backwards_sin
backwards_sinh
backwards_slice
backwards_softmax_3d
backwards_softmax_all
backwards_softmax_axis
backwards_softmax_underflow
backwards_subtract
backwards_sum_m2s
backwards_sum_m2v_0
backwards_sum_m2v_1
backwards_sum_v2s
backwards_tan
backwards_tanh
batch_norm_one_output
batch_norm_three_outputs
batchnorm_bprop_n4c3h2w2
batchnorm_fprop_b1c2h2w2
batchnorm_fprop_b2c2h2w1
batchnorm_fprop_globalstats_b2c2w2h1
batchnorm_fprop_inference_b2c2h2w1
broadcast_algo_3d_backward
broadcast_algo_3d_stride_1
broadcast_algo_3d_stride_2
broadcast_algo_matrix_backward_4
broadcast_algo_matrix_stride_1
broadcast_algo_matrix_stride_2
broadcast_algo_matrix_stride_3
broadcast_algo_scalar
broadcast_algo_vector_backward_2
broadcast_algo_vector_backward_3
broadcast_algo_vector_backward_4
broadcast_algo_vector_forward_2
broadcast_algo_vector_forward_3
broadcast_algo_vector_forward_4
broadcast_algo_vector_middle
broadcast_matrix_0
broadcast_matrix_1
broadcast_matrix_2
broadcast_scalar_matrix
broadcast_scalar_tensor
broadcast_scalar_to_matrix_int32
broadcast_scalar_to_matrix_int64
broadcast_scalar_vector
broadcast_to_non_existent_axis
broadcast_trivial
broadcast_vector_colwise
broadcast_vector_rowwise
broadcast_vector_rowwise_int64
broadcast_vector_rowwise_reversed
ceiling
computation_reuse
concat_2d_tensor
concat_4d_tensor
concat_5d
concat_matrix_colwise
concat_matrix_int64
concat_matrix_rowwise
concat_vector
concat_zero_length_1d_last
concat_zero_length_1d_middle
concat_zero_length_4d_middle
constant_broadcast
constant_equality_bool
constant_multi_use
convert_float32_bool
convert_int32_bool
convert_int32_float32
convert_uint16_float32
convolution_2d_1item
convolution_2d_1item_1o1i_data_dilated
convolution_2d_1item_2o1i_data_dilated
convolution_2d_1item_2o2i_data_dilated
convolution_2d_1item_5o3i_data_dilated
convolution_2d_1item_padded_1_1x1_1
convolution_2d_1item_padded_2_3x4_5
convolution_2d_2item_5o3i_data_dilated
convolution_2d_2items
convolution_2d_2items_dilated
convolution_2d_2items_dilated_padded
convolution_2d_2items_strided
convolution_2d_2items_strided_padded
convolution_2d_2items_strided_padded_same
convolution_2d_8item_large_5o3i_data_dilated
convolution_2d_8item_large_5o3i_uneven_filter_data_dilated
convolution_2d_8item_large_5o3i_uneven_filter_uneven_data_dilation_data_dilated
convolution_3d_1item_large_5o3i_padded_uneven_filter_uneven_data_dilation_data_dilated
convolution_3d_2item_large_5o3i_padded_strided_uneven_filter_uneven_data_dilation_data_dilated
convolution_3d_2item_large_5o3i_padded_strided_uneven_filter_uneven_data_dilation_filter_dilated_data_dilated
convolution_3d_2item_large_5o3i_uneven_filter_uneven_data_dilation_data_dilated
convolution_3d_2items
convolution_4d_2items
convolution_4d_4items
convolution_4d_4items_dilated
convolution_4d_4items_padded_neg
convolution_4d_4items_strided
convolution_4d_4items_strided_dilated
convolution_4d_4items_strided_dilated_padded
convolution_4d_4items_strided_dilated_padded_neg
convolution_4d_4items_strided_dilated_padded_same
convolution_outlining
cos
cosh
dequantize
dequantize_axes
dequantize_int8
divide
divide_adjoint_stability
divide_by_zero_float32
divide_by_zero_int32
divide_overload
dot_0_0
dot_2x0_0
dot_3d_multi_axis
dot_3d_one_axis_arbitrary
dot_4d_5d_multi_axis
dot_4d_5d_multi_axis_more
dot_matrix_0x2_2x0
dot_matrix_2x0_0x2
dot_matrix_3x2_2x0
dot_matrix_vector
dot_matrix_vector_4_3
dot_matrix_vector_int64
dot_scalar_0x2
dot_scalar_scalar
dot_scalar_tensor_arg0
dot_scalar_tensor_arg1
dot1d
dot2d
dot3d_2d
dot3d_3d
equal
exp
floor
function_call
function_name
fuse_max_with_constant_zero_input_as_relu
greater
greatereq
kahan_sum_3d_to_vector
kahan_sum_to_scalar
less
lesseq
lesseq_bool
log
logical_and
logical_or
lrn
max_3d_eliminate_zero_dim
max_3d_to_matrix_least_sig
max_3d_to_matrix_most_sig
max_3d_to_scalar
max_3d_to_vector
max_matrix_cols_zero
max_matrix_columns
max_matrix_rows
max_matrix_rows_zero
max_matrix_to_scalar_zero_by_zero
max_pool_1d_1channel_1image
max_pool_1d_1channel_2image
max_pool_1d_2channel_2image
max_pool_2d_1channel_1image_overpadded
max_pool_2d_1channel_1image_padded
max_pool_2d_1channel_1image_padded_negative_values
max_pool_2d_1channel_1image_strided
max_pool_2d_2channel_2image
max_pool_2d_2channel_2image_asym_pad
max_pool_3d
max_to_scalar
max_trivial
max_trivial_5d
max_vector_zero
maximum
maximum_int32
maximum_int64
min_3d_eliminate_zero_dim
min_3d_to_matrix_least_sig
min_3d_to_matrix_most_sig
min_3d_to_scalar
min_3d_to_vector
min_matrix_cols_zero
min_matrix_columns
min_matrix_rows
min_matrix_rows_zero
min_matrix_to_scalar_zero_by_zero
min_to_scalar
min_trivial
min_trivial_5d
min_vector_zero
minimum
minimum_int32
minimum_int64
multiple_backends
multiple_result
multiply
multiply_overload
negative
node_name
not
notequal
numeric_double_inf
numeric_double_nan
numeric_float_inf
numeric_float_nan
one_hot_matrix_0
one_hot_scalar_0_in_3
one_hot_scalar_1_in_3
one_hot_scalar_2_in_3
one_hot_scalar_fp_nonint_in_3
one_hot_scalar_oob_in_3
one_hot_vector_0
one_hot_vector_1
one_hot_vector_1_barely_oob
one_hot_vector_1_far_oob
one_hot_vector_1_fp
one_hot_vector_1_fp_nonint
pad_2channel_2image_asym
pad_exterior_1d
pad_exterior_2d_0x0
pad_exterior_2d_0x3
pad_exterior_2d_3x0
pad_exterior_4d_1x2x2x2
pad_interior_1d
pad_interior_exterior_1d
pad_interior_exterior_2d
pad_interior_exterior_4d_2x0x3x2
parameter_as_output
power
product_3d_eliminate_zero_dim
product_3d_to_matrix_least_sig
product_3d_to_matrix_most_sig
product_3d_to_scalar
product_3d_to_vector
product_matrix_cols_zero
product_matrix_columns
product_matrix_rows
product_matrix_rows_zero
product_matrix_to_scalar_zero_by_zero
product_to_scalar
product_trivial
product_trivial_5d
product_vector_zero
quantize
quantize_axes
quantize_clamp
quantize_int8
reduce_3d_to_vector
reduce_matrix_cols_zero
reduce_matrix_columns
reduce_matrix_rows
reduce_matrix_rows_zero
reduce_matrix_to_scalar_zero_by_zero
reduce_to_scalar
reduce_trivial
reduce_vector_zero
reduce_window_emulating_max_pool_1d_1channel_1image
reduce_window_emulating_max_pool_1d_1channel_2image
reduce_window_emulating_max_pool_1d_2channel_2image
reduce_window_emulating_max_pool_2d_1channel_1image_strided
reduce_window_emulating_max_pool_2d_2channel_2image
relu_2Dbackprop
relu_2Dfprop
relu_4Dbackprop
relu_4Dfprop
replace_slice_3d
replace_slice_3d_strided
replace_slice_3d_strided_different_strides
replace_slice_matrix
replace_slice_matrix_inplace
replace_slice_scalar
replace_slice_vector
reshape_3d_transpose_021
reshape_3d_transpose_102
reshape_3d_transpose_120
reshape_3d_transpose_201
reshape_3d_transpose_210
reshape_4d_no_transpose
reshape_4d_transpose
reshape_6d
reshape_m2m_dim_change_transpose
reshape_m2m_same
reshape_m2m_transpose
reshape_s2t
reshape_s2t1
reshape_t2s_012
reshape_t2s_120
reshape_t2v_012
reshape_transposed_shape_change
reshape_v2m_col
reshape_v2m_row
reshape_v2t_middle
reverse_0d
reverse_1d_0
reverse_1d_nochange
reverse_2d_0
reverse_2d_01
reverse_2d_1
reverse_2d_nochange
reverse_3d_0
reverse_3d_01
reverse_3d_012
reverse_3d_02
reverse_3d_1
reverse_3d_12
reverse_3d_2
reverse_3d_nochange
reverse_sequence_n2c3h4w2
reverse_sequence_n4c3h2w2
reverse_sequence_n4d2c3h2w2
scalar_constant_float32
scalar_constant_int64
select
select_and_scatter_3d_without_overlap
select_and_scatter_with_overlap
select_and_scatter_without_overlap
sigmoid_bprop_n1c1h4
sigmoid_n1c1h2w2
sigmoid_n1c1h4
sign
sin
sinh
slice_3d
slice_3d_strided
slice_3d_strided_different_strides
slice_matrix
slice_matrix_strided
slice_scalar
slice_vector
softmax_all
softmax_axis
softmax_axis_2
softmax_axis_3d
softmax_axis_3d_trivial
softmax_underflow
sqrt
subtract
subtract_overload
sum_3d_eliminate_zero_dim
sum_3d_to_matrix_least_sig
sum_3d_to_matrix_most_sig
sum_3d_to_scalar
sum_3d_to_vector
sum_5d_to_scalar
sum_large_1d_to_scalar
sum_matrix_6d
sum_matrix_cols_zero
sum_matrix_columns
sum_matrix_rows
sum_matrix_rows_zero
sum_matrix_to_scalar_zero_by_zero
sum_to_scalar
sum_trivial
sum_trivial_5d
sum_vector_zero
tan
tanh
tensor_2constant
tensor_constant
tensor_constant_float32
tensor_constant_int64
tensor_constant_with_op
tensorview_custom_mem
topk_1d_max_all
topk_1d_max_one
topk_1d_max_partial
topk_1d_min_all
topk_1d_min_one
topk_1d_min_partial
topk_2d_max_all
topk_2d_max_one
topk_2d_max_partial
topk_2d_min_all
topk_2d_min_one
topk_2d_min_partial
topk_3d_max_all
topk_3d_max_one
topk_3d_max_partial
topk_3d_min_all
topk_3d_min_one
topk_3d_min_partial
unhandled_op
validate_call_input_count
validate_call_input_shape
validate_call_input_type
validate_call_output_count
validate_call_output_shape
validate_call_output_type
zero_sized_abs
zero_sized_acos
zero_sized_add
zero_sized_asin
zero_sized_atan
zero_sized_ceiling
zero_sized_cos
zero_sized_cosh
zero_sized_divide
zero_sized_eq
zero_sized_exp
zero_sized_floor
zero_sized_greater
zero_sized_greatereq
zero_sized_less
zero_sized_lesseq
zero_sized_log
zero_sized_maximum
zero_sized_minimum
zero_sized_multiply
zero_sized_negative
zero_sized_not
zero_sized_not_equal
zero_sized_power
zero_sized_sign
zero_sized_sin
zero_sized_sinh
zero_sized_sqrt
zero_sized_subtract
zero_sized_tan
zero_sized_tanh
@@ -65,6 +65,10 @@ if (NGRAPH_INTERPRETER_ENABLE)
set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} INTERPRETER)
endif()
if (NGRAPH_HYBRID_ENABLE)
set(SRC ${SRC} hybrid_backend.cpp)
endif()
if (NGRAPH_CPU_ENABLE)
list(APPEND SRC core_fusion.cpp quantize_cpu.cpp)
list(APPEND SRC backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_reshape_sinking.cpp)
@@ -83,6 +87,10 @@ if (NGRAPH_INTELGPU_ENABLE)
set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} INTELGPU)
endif()
if (NGRAPH_HYBRID_ENABLE)
set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} HYBRID)
endif()
add_subdirectory(models)
add_subdirectory(files)
add_subdirectory(util)
@@ -179,6 +187,10 @@ if (NGRAPH_INTERPRETER_ENABLE)
target_link_libraries(unit-test PRIVATE interpreter_backend)
endif()
if (NGRAPH_HYBRID_ENABLE)
target_link_libraries(unit-test PRIVATE hybrid_backend)
endif()
if (NGRAPH_GPU_ENABLE)
target_link_libraries(unit-test PRIVATE gpu_backend)
endif()
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************