Unverified Commit 8f102516 authored by Matthew Brookhart, committed by GitHub

Merge branch 'master' into ayzhuang/in-place-concat

parents 3feb4264 982889f5
......@@ -40,33 +40,28 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
void op::Broadcast::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
infer_shape();
for (auto axis : m_broadcast_axes)
{
return;
NODE_VALIDATION_ASSERT(this, axis < m_shape.size())
<< "Broadcast axis index (" << axis << ") exceeds specified output shape rank "
<< "(broadcast axes: " << m_broadcast_axes << ", output shape: " << m_shape << ").";
}
infer_shape();
Shape target_shape = m_shape;
Shape required_input_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
NODE_VALIDATION_ASSERT(this, *i < target_shape.size())
<< "Broadcast axis index (" << *i << ") exceeds target shape rank "
<< "(broadcast axes: " << m_broadcast_axes << ", target shape: " << target_shape
<< ").";
target_shape.erase(target_shape.begin() + *i);
required_input_shape.erase(required_input_shape.begin() + *i);
}
// TODO(amprocte): We can probably have a more helpful error message here.
// There are two things that can go wrong, which are being picked up in
// one fell swoop by this check: either the number of broadcast axes is not
// enough (arg->get_shape().size() + broadcast_axes.size() != shape.size())
// or there is a mismatch with one of the pre-broadcast axis lengths
// (i.e. target_shape.size() == arg->get_shape.size() but there is some i
// where target_shape[i] != arg->get_shape[i]).
NODE_VALIDATION_ASSERT(this, target_shape == get_input_shape(0))
<< "Broadcast argument shape, target shape, and axes are incompatible "
<< "(argument shape: " << get_input_shape(0) << ", target shape: " << m_shape
// enough, or there is a mismatch with one of the pre-broadcast axis lengths.
NODE_VALIDATION_ASSERT(this, get_input_partial_shape(0).compatible(required_input_shape))
<< "Broadcast argument shape, specified output shape, and axes are incompatible "
<< "(argument shape: " << get_input_partial_shape(0) << ", output shape: " << m_shape
<< ", broadcast axes: " << m_broadcast_axes << ").";
set_output_type(0, get_input_element_type(0), m_shape);
......
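The updated check builds the shape the argument is required to have by deleting the broadcast axes (highest index first) from the requested output shape, then tests it against the possibly-partial input shape with compatible(). A minimal standalone sketch of that bookkeeping, using plain std::vector<size_t> in place of ngraph::Shape (illustrative only; the real compatible() call additionally tolerates dynamic dimensions):

#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Given the requested output shape and the broadcast axes, erase the axes
// (highest index first) to recover the shape the input argument must have.
std::vector<size_t> required_input_shape(std::vector<size_t> output_shape,
                                         const std::set<size_t>& broadcast_axes)
{
    for (auto it = broadcast_axes.rbegin(); it != broadcast_axes.rend(); ++it)
    {
        output_shape.erase(output_shape.begin() + *it);
    }
    return output_shape;
}

int main()
{
    // Broadcasting a {3, 4} argument to {2, 3, 4, 5} along axes {0, 3}.
    std::vector<size_t> out_shape{2, 3, 4, 5};
    std::set<size_t> axes{0, 3};
    assert((required_input_shape(out_shape, axes) == std::vector<size_t>{3, 4}));
    return 0;
}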
......@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
, m_strides(strides)
{
constructor_validate_and_infer_types();
check_args();
}
op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
......@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
, m_strides(Strides(lower_bounds.size(), 1))
{
constructor_validate_and_infer_types();
check_args();
}
void op::ReplaceSlice::check_args()
void op::ReplaceSlice::validate_and_infer_types()
{
auto& input_0 = get_inputs().at(0);
auto& input_0_shape = input_0.get_shape();
auto& input_0_element_type = input_0.get_element_type();
auto& input_1 = get_inputs().at(1);
auto& input_1_shape = input_1.get_shape();
auto& input_1_element_type = input_1.get_element_type();
// An empty strides vector with lower_bounds/upper_bounds filled in means that we need to
// construct the default value (unit strides).
if (m_strides.size() == 0)
{
m_strides = Strides(m_lower_bounds.size(), 1);
}
NODE_VALIDATION_ASSERT(this, input_0_shape.size() == input_1_shape.size())
<< "Argument ranks do not match (arg0 shape: " << input_0_shape
<< ", arg1 shape: " << input_1_shape << ").";
const PartialShape& arg0_shape = get_input_partial_shape(0);
const PartialShape& arg1_shape = get_input_partial_shape(1);
Dimension merged_args_rank;
NODE_VALIDATION_ASSERT(this, input_0_element_type == input_1_element_type)
<< "Argument element types do not match (arg0 element type: " << input_0_element_type
<< ", arg1 element type: " << input_1_element_type << ").";
NODE_VALIDATION_ASSERT(this,
Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
<< "Argument ranks do not match (arg0 shape: " << arg0_shape
<< ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_0_shape.size())
<< "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
<< "of argument (" << input_0_shape.size() << ") (lower bounds: " << m_lower_bounds
<< ", argument shape: " << input_0_shape << ").";
element::Type arg0_et = get_input_element_type(0);
element::Type arg1_et = get_input_element_type(1);
element::Type merged_args_et;
NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_0_shape.size())
<< "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank "
<< "of argument (" << input_0_shape.size() << ") (upper bounds: " << m_upper_bounds
<< ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
<< "Argument element types do not match (arg0 element type: " << arg0_et
<< ", arg1 element type: " << arg1_et << ").";
NODE_VALIDATION_ASSERT(this, m_strides.size() == input_0_shape.size())
<< "Rank of strides (" << m_strides.size() << ") does not match rank "
<< "of argument (" << input_0_shape.size() << ") (strides: " << m_strides
<< ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size())
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
Shape slice_shape;
size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < input_0_shape.size(); i++)
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_0_shape[i])
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ").";
}
size_t slice_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
slice_axis_size =
slice_axis_size / m_strides[i] + ((slice_axis_size % m_strides[i] == 0) ? 0 : 1);
slice_shape.push_back(slice_axis_size);
NODE_VALIDATION_ASSERT(this,
merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
<< "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
std::vector<Dimension> sliced_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
sliced_dims[i] = sliced_dim;
}
NODE_VALIDATION_ASSERT(this, input_1_shape == slice_shape)
<< "Shape of replacement tensor (" << input_1_shape << ") does not match the slice shape "
PartialShape slice_shape{sliced_dims};
NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
<< "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
<< "(" << slice_shape << ").";
set_output_type(0, input_0_element_type, input_0_shape);
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attributes will have given us enough info.
PartialShape result_shape =
(arg0_shape.rank().is_static())
? arg0_shape
: PartialShape(std::vector<Dimension>(output_rank, Dimension::dynamic()));
set_output_type(0, merged_args_et, result_shape);
}
shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
......
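Both ReplaceSlice above and Slice further down compute the extent of each sliced axis the same way: (upper - lower) divided by the stride, rounded up. A small self-contained check of that arithmetic (the helper name is illustrative, not part of the nGraph API):

#include <cassert>
#include <cstddef>

// Extent of a strided slice along one axis: ceil((upper - lower) / stride).
size_t sliced_extent(size_t lower, size_t upper, size_t stride)
{
    size_t span = upper - lower;
    return span / stride + ((span % stride == 0) ? 0 : 1);
}

int main()
{
    assert(sliced_extent(0, 8, 1) == 8); // whole axis, unit stride
    assert(sliced_extent(1, 7, 2) == 3); // elements 1, 3, 5
    assert(sliced_extent(2, 7, 3) == 2); // elements 2, 5
    return 0;
}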
......@@ -88,11 +88,11 @@ namespace ngraph
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void check_args();
void validate_and_infer_types() override;
const Coordinate m_lower_bounds;
const Coordinate m_upper_bounds;
const Strides m_strides;
Coordinate m_lower_bounds;
Coordinate m_upper_bounds;
Strides m_strides;
};
}
}
......@@ -32,27 +32,30 @@ op::Select::Select(const shared_ptr<Node>& arg0,
: Op("Select", check_single_output_args({arg0, arg1, arg2}))
{
constructor_validate_and_infer_types();
}
auto& input_0 = get_inputs().at(0);
auto& input_1 = get_inputs().at(1);
auto& input_2 = get_inputs().at(2);
NODE_VALIDATION_ASSERT(this, input_0.get_element_type() == element::boolean)
void op::Select::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::boolean)
<< "Argument 0 does not have boolean element type (element type: "
<< input_0.get_element_type() << ").";
<< get_input_element_type(0) << ").";
NODE_VALIDATION_ASSERT(this,
input_0.get_shape() == input_1.get_shape() &&
input_0.get_shape() == input_2.get_shape())
<< "Arguments do not all have the same shape (arg0 shape: " << input_0.get_shape()
<< ", arg1 shape: " << input_1.get_shape() << ", arg2 shape: " << input_2.get_shape()
<< ").";
PartialShape result_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(1)))
<< "Argument shapes are inconsistent.";
NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(2)))
<< "Argument shapes are inconsistent.";
element::Type result_et;
NODE_VALIDATION_ASSERT(this, input_1.get_element_type() == input_2.get_element_type())
<< "Arguments 1 and 2 do not have the same element type (arg1 type: "
<< input_1.get_element_type() << ", arg2 type: " << input_2.get_element_type() << ").";
NODE_VALIDATION_ASSERT(
this, element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)))
<< "Argument 1 and 2 element types are inconsistent.";
set_output_type(0, input_1.get_element_type(), input_1.get_shape());
set_output_type(0, result_et, result_shape);
}
shared_ptr<Node> op::Select::copy_with_new_args(const NodeVector& new_args) const
......
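The rewritten Select validation no longer demands identical static shapes and element types; it merges them, so a dynamic dimension or dynamic element type is accepted wherever the other arguments can supply the missing information. A rough sketch of the per-dimension merge rule, with std::optional standing in for a possibly-dynamic Dimension (illustrative only; the real Dimension and PartialShape classes carry more machinery than this):

#include <cassert>
#include <optional>

// A dimension is either a known size or "dynamic" (unknown).
using Dim = std::optional<size_t>;

// Two dimensions merge if either is dynamic or both agree; the merged value
// keeps whatever information is available.
bool merge_dim(Dim& dst, const Dim& a, const Dim& b)
{
    if (!a.has_value())
    {
        dst = b;
        return true;
    }
    if (!b.has_value())
    {
        dst = a;
        return true;
    }
    if (*a != *b)
    {
        return false;
    }
    dst = a;
    return true;
}

int main()
{
    Dim d;
    assert(merge_dim(d, std::nullopt, Dim{4}) && d == Dim{4}); // dynamic vs 4 -> 4
    assert(merge_dim(d, Dim{4}, Dim{4}) && d == Dim{4});       // 4 vs 4 -> 4
    assert(!merge_dim(d, Dim{4}, Dim{5}));                     // 4 vs 5 -> inconsistent
    return 0;
}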
......@@ -53,6 +53,7 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
protected:
void validate_and_infer_types() override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
......
......@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
void op::Slice::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
if (0 == m_strides.size())
// An empty strides vector with lower_bounds/upper_bounds filled in means that we need to
// construct the default value (unit strides).
if (m_strides.size() == 0)
{
m_strides = Strides(m_lower_bounds.size(), 1);
}
auto& input = get_inputs().at(0);
auto& input_shape = input.get_shape();
NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_shape.size())
<< "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (lower bounds: " << m_lower_bounds
<< ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_shape.size())
<< "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (upper bounds: " << m_upper_bounds
<< ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size())
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
NODE_VALIDATION_ASSERT(this, m_strides.size() == input_shape.size())
<< "Rank of strides (" << m_strides.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (strides: " << m_strides
<< ", argument shape: " << input_shape << ").";
size_t output_rank = m_upper_bounds.size();
Shape result_shape;
for (size_t i = 0; i < input_shape.size(); i++)
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_shape[i])
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ").";
}
const PartialShape& input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || size_t(input_rank) == output_rank)
<< "Input rank does not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
std::vector<Dimension> result_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(input_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
result_axis_size =
result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1);
result_shape.push_back(result_axis_size);
result_dims[i] = result_axis_size;
}
set_output_type(0, input.get_element_type(), result_shape);
set_output_type(0, get_input_element_type(0), PartialShape{result_dims});
}
shared_ptr<Node> op::Slice::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -81,6 +81,7 @@ set(SRC
op/batch_norm_relu.cpp
op/bounded_relu.cpp
op/group_conv.cpp
op/halide_op.cpp
op/conv_bias.cpp
op/conv_relu.cpp
op/convert_layout.cpp
......@@ -115,6 +116,14 @@ if (NOT NGRAPH_DEX_ONLY)
)
endif()
if (NGRAPH_HALIDE)
set(SRC
${SRC}
builder/halide_op.cpp
pass/halide_subgraph_extraction.cpp
)
endif()
if (NGRAPH_TBB_ENABLE)
include(${TBB_ROOT}/cmake/TBBBuild.cmake)
tbb_build(TBB_ROOT ${TBB_ROOT} MAKE_ARGS tbb_build_dir=${CMAKE_CURRENT_BINARY_DIR}/tbb_build
......@@ -152,6 +161,12 @@ if (NGRAPH_CPU_ENABLE)
if (NGRAPH_DEX_ONLY)
target_compile_definitions(cpu_backend PRIVATE "NGRAPH_DEX_ONLY")
endif()
if (NGRAPH_HALIDE)
target_compile_definitions(cpu_backend PRIVATE "NGRAPH_HALIDE")
ExternalProject_Get_Property(ext_halide BINARY_DIR)
target_include_directories(cpu_backend SYSTEM PRIVATE ${BINARY_DIR}/include)
target_link_libraries(cpu_backend PRIVATE ${BINARY_DIR}/lib/libHalide.so)
endif()
if(OPENMP_FOUND)
target_compile_options(cpu_backend PRIVATE "${OpenMP_CXX_FLAGS}")
......
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <Halide.h>
#include <HalideBuffer.h>
#include <functional>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
using namespace std;
using namespace ngraph;
#define TI(x) type_index(typeid(x))
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace halide
{
static const std::unordered_map<std::type_index,
std::function<Halide::Func(vector<Halide::Func>)>>
generators{{TI(ngraph::op::Add),
[](vector<Halide::Func> in) {
Halide::Var x;
Halide::Func func;
func(x) = in[0](x) + in[1](x);
return func;
}},
{TI(ngraph::op::Multiply),
[](vector<Halide::Func> in) {
Halide::Var x;
Halide::Func func;
func(x) = in[0](x) * in[1](x);
return func;
}},
{TI(ngraph::op::Relu), [](vector<Halide::Func> in) {
Halide::Var x;
Halide::Func func;
func(x) = Halide::max(in[0](x), 0);
return func;
}}};
}
template <>
void Builder::BUILDER_DECL(ngraph::runtime::cpu::op::HalideOp)
{
const ngraph::runtime::cpu::op::HalideOp* hs =
static_cast<const ngraph::runtime::cpu::op::HalideOp*>(node);
auto& halide_functions = external_function->get_halide_functions();
auto& subgraph_params = external_function->get_subgraph_params();
auto& subgraph_param_sizes = external_function->get_subgraph_param_sizes();
auto& subgraph_param_ptrs = external_function->get_subgraph_param_ptrs();
for (const auto& op : hs->get_ops())
{
if (!halide::generators.count(TI(*op)))
{
throw ngraph_error("Invalid op in halide subgraph");
}
vector<Halide::Func> inputs;
for (const auto& input : op->get_inputs())
{
auto tensor_name = input.get_output().get_tensor_ptr()->get_name();
if (halide_functions.count(tensor_name))
{
inputs.emplace_back(halide_functions[tensor_name]);
}
else
{
subgraph_params[tensor_name] = Halide::ImageParam(Halide::Float(32), 1);
subgraph_param_sizes[tensor_name] =
shape_size(input.get_output().get_tensor_ptr()->get_shape());
subgraph_param_ptrs.emplace(
tensor_name, external_function->get_tensor_data(tensor_name));
inputs.emplace_back(subgraph_params[tensor_name]);
}
}
halide_functions[op->get_output_tensor_ptr()->get_name()] =
halide::generators.at(TI(*op))(inputs);
}
auto out_tensor_name = hs->get_ops().back()->get_output_tensor_ptr()->get_name();
auto& functors = external_function->get_functors();
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto& terminal_func = halide_functions[out_tensor_name];
auto out_size = out[0].get_size();
auto functor = [&, out_size](CPURuntimeContext* ctx) {
for (auto& param : subgraph_params)
{
Halide::Buffer<float> param_buffer(
static_cast<float*>(subgraph_param_ptrs.at(param.first).get()),
subgraph_param_sizes.at(param.first));
param.second.set(param_buffer);
}
Halide::Buffer<float> out_buffer(static_cast<float*>(out_tensor), out_size);
terminal_func.realize(out_buffer);
};
functors.emplace_back(functor);
}
}
}
}
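For readers unfamiliar with the Halide API used in this builder: each generator returns a Halide::Func defined over a Var, Funcs compose by calling one another, and realize() JIT-compiles and runs the pipeline into a buffer, just as the generated functor does. A minimal standalone sketch of the same Add-then-Relu pattern outside of nGraph (assuming a Halide installation; values and names are illustrative):

#include <Halide.h>
#include <HalideBuffer.h>
#include <cassert>

int main()
{
    // Two 1-D float inputs, mirroring the ImageParams the builder creates.
    Halide::ImageParam in0(Halide::Float(32), 1);
    Halide::ImageParam in1(Halide::Float(32), 1);

    // Compose Funcs the same way the generator table does: add, then relu.
    Halide::Var x;
    Halide::Func add_f, relu_f;
    add_f(x) = in0(x) + in1(x);
    relu_f(x) = Halide::max(add_f(x), 0.0f);

    float a[4] = {-3.0f, -1.0f, 1.0f, 3.0f};
    float b[4] = {1.0f, 1.0f, 1.0f, 1.0f};
    in0.set(Halide::Buffer<float>(a, 4));
    in1.set(Halide::Buffer<float>(b, 4));

    // Realize into a preallocated buffer, as the generated functor does.
    float out[4];
    Halide::Buffer<float> out_buffer(out, 4);
    relu_f.realize(out_buffer);

    assert(out[0] == 0.0f && out[3] == 4.0f); // relu(-3+1)=0, relu(3+1)=4
    return 0;
}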
......@@ -98,6 +98,7 @@
#include "ngraph/runtime/cpu/kernel/tan.hpp"
#include "ngraph/runtime/cpu/kernel/tanh.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
......@@ -367,7 +368,9 @@ namespace ngraph
static BuildOpMap build_dispatcher{
{TI(ngraph::op::Parameter), &runtime::cpu::Builder::nop},
{TI(ngraph::runtime::cpu::op::ConvertLayout),
&runtime::cpu::Builder::build<ngraph::runtime::cpu::op::ConvertLayout>}};
&runtime::cpu::Builder::build<ngraph::runtime::cpu::op::ConvertLayout>},
{TI(ngraph::runtime::cpu::op::HalideOp),
&runtime::cpu::Builder::build<ngraph::runtime::cpu::op::HalideOp>}};
return build_dispatcher;
}
......
......@@ -170,6 +170,7 @@
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/op/allreduce.hpp"
......@@ -1023,6 +1024,10 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
// pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
#if defined(NGRAPH_HALIDE)
pass_manager.register_pass<ngraph::runtime::cpu::pass::HalideSubgraphExtraction>();
#endif
NodeVector nv_cwi; // We don't need CPUWorkspaceInsertion to return a list of indices
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi, false);
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
......
......@@ -27,6 +27,10 @@
#include <utility>
#include <vector>
#if defined(NGRAPH_HALIDE)
#include <Halide.h>
#endif
#if !defined(NGRAPH_DEX_ONLY)
#include "ngraph/codegen/code_writer.hpp"
......@@ -135,6 +139,26 @@ namespace ngraph
const std::string& directory,
const std::string& filename);
#if defined(NGRAPH_HALIDE)
std::unordered_map<std::string, Halide::Func>& get_halide_functions()
{
return halide_functions;
}
std::unordered_map<std::string, Halide::ImageParam>& get_subgraph_params()
{
return subgraph_params;
}
std::unordered_map<std::string, int>& get_subgraph_param_sizes()
{
return subgraph_param_sizes;
}
std::unordered_map<std::string, std::reference_wrapper<void*>>&
get_subgraph_param_ptrs()
{
return subgraph_param_ptrs;
}
#endif
protected:
void build();
......@@ -240,6 +264,13 @@ namespace ngraph
std::unordered_map<std::string, std::shared_ptr<CPU_ExternalFunction>> callees;
bool m_is_built;
bool m_direct_execution;
#if defined(NGRAPH_HALIDE)
std::unordered_map<std::string, Halide::Func> halide_functions;
std::unordered_map<std::string, Halide::ImageParam> subgraph_params;
std::unordered_map<std::string, int> subgraph_param_sizes;
std::unordered_map<std::string, std::reference_wrapper<void*>> subgraph_param_ptrs;
#endif
};
}
}
......
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/op/halide_op.hpp"
using namespace std;
using namespace ngraph;
shared_ptr<Node> runtime::cpu::op::HalideOp::copy_with_new_args(const NodeVector& new_args) const
{
return make_shared<HalideOp>(new_args, ops, output_type, output_shape);
}
runtime::cpu::op::HalideOp::HalideOp(const NodeVector& args,
const std::list<std::shared_ptr<Node>>& ops,
const element::Type& out_type,
const Shape& out_shape)
: Op("HalideOp", check_single_output_args(args))
, ops(ops)
, output_type(out_type)
, output_shape(out_shape)
{
constructor_validate_and_infer_types();
}
void runtime::cpu::op::HalideOp::validate_and_infer_types()
{
set_output_type(0, output_type, output_shape);
}
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <list>
#include <vector>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace op
{
class HalideOp : public ngraph::op::Op
{
public:
HalideOp(const NodeVector& args,
const std::list<std::shared_ptr<Node>>& ops,
const element::Type& out_type,
const Shape& out_shape);
virtual void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const std::list<std::shared_ptr<Node>>& get_ops() const { return ops; }
private:
std::list<std::shared_ptr<Node>> ops;
element::Type output_type;
Shape output_shape;
};
}
}
}
}
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <iostream>
#include <list>
#include <typeindex>
#include <typeinfo>
#include <unordered_set>
#include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
using namespace std;
using namespace ngraph;
#define TI(x) type_index(typeid(x))
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace halide
{
static const std::unordered_set<std::type_index> whitelist{
TI(ngraph::op::Add), TI(ngraph::op::Multiply), TI(ngraph::op::Relu)};
static const std::unordered_set<std::type_index> skiplist{TI(ngraph::op::Parameter),
TI(ngraph::op::Result)};
}
}
}
}
// Support for multiple results, multiple outputs, GetOutputElement, and multiple subgraphs in a
// single pipeline is not implemented, since this pass should eventually go away in favor of the
// "hybrid" transformer approach of carving out subgraphs in core nGraph.
bool runtime::cpu::pass::HalideSubgraphExtraction::run_on_function(
std::shared_ptr<ngraph::Function> function)
{
list<shared_ptr<Node>> worklist;
auto results = function->get_results();
// Artificial limitation: only handle functions with a single f32 result for now.
if (results.size() > 1)
{
return false;
}
if (function->get_result()->get_element_type() != element::f32)
{
return false;
}
for (const auto& result : results)
{
worklist.emplace_back(result);
}
unordered_set<shared_ptr<Node>> ops;
list<shared_ptr<Node>> ordered_ops;
while (!worklist.empty())
{
const auto& node = worklist.front();
if (!halide::skiplist.count(TI(*node)))
{
if (halide::whitelist.count(TI(*node)))
{
ops.emplace(node);
ordered_ops.emplace_back(node);
}
else
{
break;
}
}
const auto& args = node->get_arguments();
for (const auto& arg : args)
{
worklist.emplace_back(arg);
}
worklist.pop_front();
}
NodeVector liveins;
for (const auto& op : ops)
{
const auto& args = op->get_arguments();
for (const auto& arg : args)
{
if (!ops.count(arg))
{
liveins.emplace_back(arg);
}
}
}
ordered_ops.reverse();
auto subgraph = make_shared<cpu::op::HalideOp>(liveins,
ordered_ops,
function->get_result()->get_element_type(),
function->get_result()->get_shape());
replace_node(function->get_result()->get_argument(0), subgraph);
return true;
}
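The pass walks backwards from the function's single result: it keeps absorbing whitelisted ops, stops at the first op it cannot fuse, treats arguments produced outside the absorbed set as live-ins, and finally reverses the visit order so the HalideOp lists its ops producer-first. A simplified standalone sketch of that traversal over a toy string-keyed graph (names and structure are illustrative, not the nGraph API):

#include <iostream>
#include <list>
#include <map>
#include <set>
#include <string>
#include <vector>

int main()
{
    // Toy graph: result <- add2 <- {relu, D}; relu <- mul; mul <- {A, B}.
    std::map<std::string, std::vector<std::string>> args = {
        {"result", {"add2"}}, {"add2", {"relu", "D"}}, {"relu", {"mul"}}, {"mul", {"A", "B"}}};
    std::set<std::string> whitelist = {"add2", "relu", "mul"};
    std::set<std::string> skiplist = {"result", "A", "B", "D"}; // results and parameters

    std::list<std::string> worklist{"result"};
    std::set<std::string> ops;
    std::list<std::string> ordered_ops;
    while (!worklist.empty())
    {
        std::string node = worklist.front();
        worklist.pop_front();
        if (!skiplist.count(node))
        {
            if (!whitelist.count(node))
            {
                break; // first unfusable op ends the subgraph, as in the real pass
            }
            ops.insert(node);
            ordered_ops.push_back(node);
        }
        for (const auto& arg : args[node])
        {
            worklist.push_back(arg);
        }
    }

    // Live-ins: arguments of fused ops that are not themselves fused (A, B, D).
    std::set<std::string> liveins;
    for (const auto& op : ops)
        for (const auto& arg : args[op])
            if (!ops.count(arg))
                liveins.insert(arg);

    ordered_ops.reverse(); // producer-first order for the fused subgraph
    for (const auto& op : ordered_ops)
        std::cout << op << " ";
    std::cout << "| live-ins:";
    for (const auto& li : liveins)
        std::cout << " " << li;
    std::cout << std::endl; // prints: mul relu add2 | live-ins: A B D
    return 0;
}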
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class HalideSubgraphExtraction : public ngraph::pass::FunctionPass
{
public:
HalideSubgraphExtraction() {}
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
}
}
}
......@@ -68,6 +68,9 @@ endif()
if (NGRAPH_CPU_ENABLE)
list(APPEND SRC core_fusion.cpp quantize_cpu.cpp)
list(APPEND SRC backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_reshape_sinking.cpp)
if (NGRAPH_HALIDE)
list(APPEND SRC halide.cpp)
endif()
set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} CPU)
endif()
......
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(halide, halide_subgraph)
{
Shape shape{8};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto D = make_shared<op::Parameter>(element::f32, shape);
auto relu = make_shared<op::Relu>((A + B) * C);
auto f = make_shared<Function>(relu + D, op::ParameterVector{A, B, C, D});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> c = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> d = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape);
vector<float> data{-1, 4, -2, 5, 1, 5, 7, 9};
copy_data(a, data);
copy_data(b, data);
copy_data(c, data);
copy_data(d, data);
vector<float> expected{1, 36, 6, 55, 3, 55, 105, 171};
backend->call_with_validate(f, {result}, {a, b, c, d});
EXPECT_TRUE(test::all_close(read_vector<float>(result), expected, 1.0e-4f, 1.0e-4f));
}