Unverified Commit 8f102516 authored by Matthew Brookhart, committed by GitHub

Merge branch 'master' into ayzhuang/in-place-concat

parents 3feb4264 982889f5
@@ -40,33 +40,28 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,

void op::Broadcast::validate_and_infer_types()
{
-    if (validate_punt_if_dynamic())
-    {
-        return;
-    }
-
    infer_shape();

-    Shape target_shape = m_shape;
+    for (auto axis : m_broadcast_axes)
+    {
+        NODE_VALIDATION_ASSERT(this, axis < m_shape.size())
+            << "Broadcast axis index (" << axis << ") exceeds specified output shape rank "
+            << "(broadcast axes: " << m_broadcast_axes << ", output shape: " << m_shape << ").";
+    }
+
+    Shape required_input_shape = m_shape;
    for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
    {
-        NODE_VALIDATION_ASSERT(this, *i < target_shape.size())
-            << "Broadcast axis index (" << *i << ") exceeds target shape rank "
-            << "(broadcast axes: " << m_broadcast_axes << ", target shape: " << target_shape
-            << ").";
-
-        target_shape.erase(target_shape.begin() + *i);
+        required_input_shape.erase(required_input_shape.begin() + *i);
    }

    // TODO(amprocte): We can probably have a more helpful error message here.
    // There are two things that can go wrong, which are being picked up in
    // one fell swoop by this check: either the number of broadcast axes is not
-    // enough (arg->get_shape().size() + broadcast_axes.size() != shape.size())
-    // or there is a mismatch with one of the pre-broadcast axis lengths
-    // (i.e. target_shape.size() == arg->get_shape.size() but there is some i
-    // where target_shape[i] != arg->get_shape[i]).
-    NODE_VALIDATION_ASSERT(this, target_shape == get_input_shape(0))
-        << "Broadcast argument shape, target shape, and axes are incompatible "
-        << "(argument shape: " << get_input_shape(0) << ", target shape: " << m_shape
+    // enough, or there is a mismatch with one of the pre-broadcast axis lengths.
+    NODE_VALIDATION_ASSERT(this, get_input_partial_shape(0).compatible(required_input_shape))
+        << "Broadcast argument shape, specified output shape, and axes are incompatible "
+        << "(argument shape: " << get_input_partial_shape(0) << ", output shape: " << m_shape
        << ", broadcast axes: " << m_broadcast_axes << ").";

    set_output_type(0, get_input_element_type(0), m_shape);
...
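The heart of the rewritten check: erasing the broadcast axes from the specified output shape yields the shape the argument must be compatible with. A minimal standalone sketch of that computation (not part of this commit; std::vector<size_t> and std::set<size_t> stand in for ngraph::Shape and AxisSet):

#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Erase broadcast axes from highest to lowest so earlier indices stay valid;
// whatever remains is the shape the input must be compatible with.
std::vector<std::size_t> required_input_shape(std::vector<std::size_t> output_shape,
                                              const std::set<std::size_t>& broadcast_axes)
{
    for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i)
    {
        output_shape.erase(output_shape.begin() + *i);
    }
    return output_shape;
}

int main()
{
    // Output shape {2, 3, 4} with broadcast axis 1 requires an input compatible with {2, 4}.
    assert((required_input_shape({2, 3, 4}, {1}) == std::vector<std::size_t>{2, 4}));
    return 0;
}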
@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
    , m_strides(strides)
{
    constructor_validate_and_infer_types();
-    check_args();
}

op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
    , m_strides(Strides(lower_bounds.size(), 1))
{
    constructor_validate_and_infer_types();
-    check_args();
}

-void op::ReplaceSlice::check_args()
+void op::ReplaceSlice::validate_and_infer_types()
{
-    auto& input_0 = get_inputs().at(0);
-    auto& input_0_shape = input_0.get_shape();
-    auto& input_0_element_type = input_0.get_element_type();
-
-    auto& input_1 = get_inputs().at(1);
-    auto& input_1_shape = input_1.get_shape();
-    auto& input_1_element_type = input_1.get_element_type();
-
-    NODE_VALIDATION_ASSERT(this, input_0_shape.size() == input_1_shape.size())
-        << "Argument ranks do not match (arg0 shape: " << input_0_shape
-        << ", arg1 shape: " << input_1_shape << ").";
-
-    NODE_VALIDATION_ASSERT(this, input_0_element_type == input_1_element_type)
-        << "Argument element types do not match (arg0 element type: " << input_0_element_type
-        << ", arg1 element type: " << input_1_element_type << ").";
-
-    NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_0_shape.size())
-        << "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
-        << "of argument (" << input_0_shape.size() << ") (lower bounds: " << m_lower_bounds
-        << ", argument shape: " << input_0_shape << ").";
-
-    NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_0_shape.size())
-        << "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank "
-        << "of argument (" << input_0_shape.size() << ") (upper bounds: " << m_upper_bounds
-        << ", argument shape: " << input_0_shape << ").";
-
-    NODE_VALIDATION_ASSERT(this, m_strides.size() == input_0_shape.size())
-        << "Rank of strides (" << m_strides.size() << ") does not match rank "
-        << "of argument (" << input_0_shape.size() << ") (strides: " << m_strides
-        << ", argument shape: " << input_0_shape << ").";
-
-    Shape slice_shape;
-
-    for (size_t i = 0; i < input_0_shape.size(); i++)
-    {
-        NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_0_shape[i])
-            << "Upper bound for slice at axis " << i << " is out of range "
-            << "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_0_shape << ").";
-
-        NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
-            << "Lower bound for slice is greater than upper bound at axis " << i
-            << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
-
-        NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
-                                                        << " (strides: " << m_strides << ").";
-
-        size_t slice_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
-        slice_axis_size =
-            slice_axis_size / m_strides[i] + ((slice_axis_size % m_strides[i] == 0) ? 0 : 1);
-        slice_shape.push_back(slice_axis_size);
-    }
-
-    NODE_VALIDATION_ASSERT(this, input_1_shape == slice_shape)
-        << "Shape of replacement tensor (" << input_1_shape << ") does not match the slice shape "
-        << "(" << slice_shape << ").";
-
-    set_output_type(0, input_0_element_type, input_0_shape);
+    // An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
+    // construct the default value.
+    if (m_strides.size() == 0)
+    {
+        m_strides = Strides(m_lower_bounds.size(), 1);
+    }
+
+    const PartialShape& arg0_shape = get_input_partial_shape(0);
+    const PartialShape& arg1_shape = get_input_partial_shape(1);
+    Dimension merged_args_rank;
+
+    NODE_VALIDATION_ASSERT(this,
+                           Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
+        << "Argument ranks do not match (arg0 shape: " << arg0_shape
+        << ", arg1 shape: " << arg1_shape << ").";
+
+    element::Type arg0_et = get_input_element_type(0);
+    element::Type arg1_et = get_input_element_type(1);
+    element::Type merged_args_et;
+
+    NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
+        << "Argument element types do not match (arg0 element type: " << arg0_et
+        << ", arg1 element type: " << arg1_et << ").";
+
+    NODE_VALIDATION_ASSERT(this,
+                           m_lower_bounds.size() == m_upper_bounds.size() &&
+                               m_lower_bounds.size() == m_strides.size())
+        << "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
+        << ") and strides (" << m_strides << ") do not match.";
+
+    size_t output_rank = m_upper_bounds.size();
+
+    for (size_t i = 0; i < output_rank; i++)
+    {
+        NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
+            << "Lower bound for slice is greater than upper bound at axis " << i
+            << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
+
+        NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
+                                                        << " (strides: " << m_strides << ").";
+    }
+
+    NODE_VALIDATION_ASSERT(this,
+                           merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
+        << "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
+        << "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
+
+    std::vector<Dimension> sliced_dims(output_rank);
+
+    for (size_t i = 0; i < output_rank; i++)
+    {
+        NODE_VALIDATION_ASSERT(this,
+                               arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
+                                   m_upper_bounds[i] <= size_t(arg0_shape[i]))
+            << "Upper bound for slice at axis " << i << " is out of range "
+            << "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
+
+        size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
+        sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
+        sliced_dims[i] = sliced_dim;
+    }
+
+    PartialShape slice_shape{sliced_dims};
+
+    NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
+        << "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
+        << "(" << slice_shape << ").";
+
+    // Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
+    // because the attribs will have given us enough info.
+    PartialShape result_shape =
+        (arg0_shape.rank().is_static())
+            ? arg0_shape
+            : PartialShape(std::vector<Dimension>(output_rank, Dimension::dynamic()));
+
+    set_output_type(0, merged_args_et, result_shape);
}

shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
...
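ReplaceSlice now merges the arguments' ranks and element types rather than demanding exact static equality: merge succeeds when the two values are compatible and yields the more specific of the two. A sketch of that semantics as used above (an analogy only; a possibly-dynamic value is modeled here as std::optional<size_t>, while nGraph's Dimension and element::Type carry more structure):

#include <cassert>
#include <cstddef>
#include <optional>

using Dim = std::optional<std::size_t>; // nullopt plays the role of "dynamic"

// Returns false when the two values are incompatible; otherwise writes the
// more specific of the two into dst.
bool merge(Dim& dst, Dim a, Dim b)
{
    if (!a)
    {
        dst = b; // dynamic merges with anything
        return true;
    }
    if (!b)
    {
        dst = a;
        return true;
    }
    if (*a != *b)
    {
        return false; // two different static values cannot merge
    }
    dst = a;
    return true;
}

int main()
{
    Dim r;
    assert(merge(r, Dim{2}, std::nullopt) && r == Dim{2}); // static wins over dynamic
    assert(!merge(r, Dim{2}, Dim{3}));                     // mismatch fails validation
    return 0;
}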
@@ -88,11 +88,11 @@ namespace ngraph
        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;
-            void check_args();
+            void validate_and_infer_types() override;

-            const Coordinate m_lower_bounds;
-            const Coordinate m_upper_bounds;
-            const Strides m_strides;
+            Coordinate m_lower_bounds;
+            Coordinate m_upper_bounds;
+            Strides m_strides;
        };
    }
}
@@ -32,27 +32,30 @@ op::Select::Select(const shared_ptr<Node>& arg0,
    : Op("Select", check_single_output_args({arg0, arg1, arg2}))
{
    constructor_validate_and_infer_types();
+}

-    auto& input_0 = get_inputs().at(0);
-    auto& input_1 = get_inputs().at(1);
-    auto& input_2 = get_inputs().at(2);
-
-    NODE_VALIDATION_ASSERT(this, input_0.get_element_type() == element::boolean)
-        << "Argument 0 does not have boolean element type (element type: "
-        << input_0.get_element_type() << ").";
-
-    NODE_VALIDATION_ASSERT(this,
-                           input_0.get_shape() == input_1.get_shape() &&
-                               input_0.get_shape() == input_2.get_shape())
-        << "Arguments do not all have the same shape (arg0 shape: " << input_0.get_shape()
-        << ", arg1 shape: " << input_1.get_shape() << ", arg2 shape: " << input_2.get_shape()
-        << ").";
-
-    NODE_VALIDATION_ASSERT(this, input_1.get_element_type() == input_2.get_element_type())
-        << "Arguments 1 and 2 do not have the same element type (arg1 type: "
-        << input_1.get_element_type() << ", arg2 type: " << input_2.get_element_type() << ").";
-
-    set_output_type(0, input_1.get_element_type(), input_1.get_shape());
+void op::Select::validate_and_infer_types()
+{
+    NODE_VALIDATION_ASSERT(this,
+                           get_input_element_type(0).is_dynamic() ||
+                               get_input_element_type(0) == element::boolean)
+        << "Argument 0 does not have boolean element type (element type: "
+        << get_input_element_type(0) << ").";
+
+    PartialShape result_shape = get_input_partial_shape(0);
+
+    NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(1)))
+        << "Argument shapes are inconsistent.";
+    NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(2)))
+        << "Argument shapes are inconsistent.";
+
+    element::Type result_et;
+
+    NODE_VALIDATION_ASSERT(
+        this, element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)))
+        << "Argument 1 and 2 element types are inconsistent.";
+
+    set_output_type(0, result_et, result_shape);
}

shared_ptr<Node> op::Select::copy_with_new_args(const NodeVector& new_args) const
...
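Select now folds all three input shapes into result_shape with successive merge_into calls, so the output keeps the most specific dimension seen at each axis; the intransitive-incompatibility test added below shows why the fold can fail even when each pair of shapes is individually compatible. A toy model of the accumulation (hypothetical helper, static ranks only, dimensions as std::optional<size_t>):

#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

using Dim = std::optional<std::size_t>; // nullopt = dynamic dimension
using PShape = std::vector<Dim>;

// Refine dst in place with src; false means the shapes are inconsistent.
bool merge_into(PShape& dst, const PShape& src)
{
    if (dst.size() != src.size())
    {
        return false; // ranks must agree in this simplified static-rank model
    }
    for (std::size_t i = 0; i < dst.size(); i++)
    {
        if (!dst[i])
        {
            dst[i] = src[i];
        }
        else if (src[i] && *dst[i] != *src[i])
        {
            return false;
        }
    }
    return true;
}

int main()
{
    // Mirrors select_partial_all_rank_static_dynamic_ok: {2,?,?} + {?,8,?} + {?,?,3} -> {2,8,3}.
    PShape r{Dim{2}, std::nullopt, std::nullopt};
    assert(merge_into(r, {std::nullopt, Dim{8}, std::nullopt}));
    assert(merge_into(r, {std::nullopt, std::nullopt, Dim{3}}));
    assert((r == PShape{Dim{2}, Dim{8}, Dim{3}}));
    return 0;
}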
@@ -53,6 +53,7 @@ namespace ngraph
                copy_with_new_args(const NodeVector& new_args) const override;

            protected:
+                void validate_and_infer_types() override;
                virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                               const NodeVector& deltas) override;
            };
...
@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg,

void op::Slice::validate_and_infer_types()
{
-    if (validate_punt_if_dynamic())
-    {
-        return;
-    }
-
-    if (0 == m_strides.size())
+    // An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
+    // construct the default value.
+    if (m_strides.size() == 0)
    {
        m_strides = Strides(m_lower_bounds.size(), 1);
    }

-    auto& input = get_inputs().at(0);
-    auto& input_shape = input.get_shape();
-
-    NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_shape.size())
-        << "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
-        << "of argument (" << input_shape.size() << ") (lower bounds: " << m_lower_bounds
-        << ", argument shape: " << input_shape << ").";
-
-    NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_shape.size())
-        << "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank "
-        << "of argument (" << input_shape.size() << ") (upper bounds: " << m_upper_bounds
-        << ", argument shape: " << input_shape << ").";
-
-    NODE_VALIDATION_ASSERT(this, m_strides.size() == input_shape.size())
-        << "Rank of strides (" << m_strides.size() << ") does not match rank "
-        << "of argument (" << input_shape.size() << ") (strides: " << m_strides
-        << ", argument shape: " << input_shape << ").";
-
-    Shape result_shape;
-
-    for (size_t i = 0; i < input_shape.size(); i++)
-    {
-        NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_shape[i])
-            << "Upper bound for slice at axis " << i << " is out of range "
-            << "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
-
-        NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
-            << "Lower bound for slice is greater than upper bound at axis " << i
-            << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
-
-        NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
-                                                        << " (strides: " << m_strides << ").";
-
-        size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
-        result_axis_size =
-            result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1);
-        result_shape.push_back(result_axis_size);
-    }
-
-    set_output_type(0, input.get_element_type(), result_shape);
+    NODE_VALIDATION_ASSERT(this,
+                           m_lower_bounds.size() == m_upper_bounds.size() &&
+                               m_lower_bounds.size() == m_strides.size())
+        << "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
+        << ") and strides (" << m_strides << ") do not match.";
+
+    size_t output_rank = m_upper_bounds.size();
+
+    for (size_t i = 0; i < output_rank; i++)
+    {
+        NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
+            << "Lower bound for slice is greater than upper bound at axis " << i
+            << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
+
+        NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
+                                                        << " (strides: " << m_strides << ").";
+    }
+
+    const PartialShape& input_shape = get_input_partial_shape(0);
+    Dimension input_rank = input_shape.rank();
+
+    NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || size_t(input_rank) == output_rank)
+        << "Input rank does not match the rank of the lower bounds (" << m_lower_bounds
+        << "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
+
+    std::vector<Dimension> result_dims(output_rank);
+
+    for (size_t i = 0; i < output_rank; i++)
+    {
+        NODE_VALIDATION_ASSERT(this,
+                               input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
+                                   m_upper_bounds[i] <= size_t(input_shape[i]))
+            << "Upper bound for slice at axis " << i << " is out of range "
+            << "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
+
+        size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
+        result_axis_size =
+            result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1);
+        result_dims[i] = result_axis_size;
+    }
+
+    set_output_type(0, get_input_element_type(0), PartialShape{result_dims});
}

shared_ptr<Node> op::Slice::copy_with_new_args(const NodeVector& new_args) const
...
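Both Slice and ReplaceSlice size each output axis as a ceiling division of the bound span by the stride. A small sketch (not from this commit) of the integer form used above:

#include <cassert>
#include <cstddef>

// ceil((upper - lower) / stride) without floating point; validation has already
// guaranteed lower <= upper and stride != 0.
std::size_t slice_axis_size(std::size_t lower, std::size_t upper, std::size_t stride)
{
    std::size_t span = upper - lower;
    return span / stride + ((span % stride == 0) ? 0 : 1);
}

int main()
{
    assert(slice_axis_size(4, 7, 2) == 2); // bounds [4,7) with stride 2 -> elements 4 and 6
    assert(slice_axis_size(2, 3, 1) == 1);
    assert(slice_axis_size(1, 1, 1) == 0); // empty slice
    return 0;
}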
@@ -81,6 +81,7 @@ set(SRC
    op/batch_norm_relu.cpp
    op/bounded_relu.cpp
    op/group_conv.cpp
+    op/halide_op.cpp
    op/conv_bias.cpp
    op/conv_relu.cpp
    op/convert_layout.cpp
@@ -115,6 +116,14 @@ if (NOT NGRAPH_DEX_ONLY)
    )
endif()

+if (NGRAPH_HALIDE)
+    set(SRC
+        ${SRC}
+        builder/halide_op.cpp
+        pass/halide_subgraph_extraction.cpp
+        )
+endif()
+
if (NGRAPH_TBB_ENABLE)
    include(${TBB_ROOT}/cmake/TBBBuild.cmake)
    tbb_build(TBB_ROOT ${TBB_ROOT} MAKE_ARGS tbb_build_dir=${CMAKE_CURRENT_BINARY_DIR}/tbb_build
@@ -152,6 +161,12 @@ if (NGRAPH_CPU_ENABLE)
    if (NGRAPH_DEX_ONLY)
        target_compile_definitions(cpu_backend PRIVATE "NGRAPH_DEX_ONLY")
    endif()
+    if (NGRAPH_HALIDE)
+        target_compile_definitions(cpu_backend PRIVATE "NGRAPH_HALIDE")
+        ExternalProject_Get_Property(ext_halide BINARY_DIR)
+        target_include_directories(cpu_backend SYSTEM PRIVATE ${BINARY_DIR}/include)
+        target_link_libraries(cpu_backend PRIVATE ${BINARY_DIR}/lib/libHalide.so)
+    endif()
    if(OPENMP_FOUND)
        target_compile_options(cpu_backend PRIVATE "${OpenMP_CXX_FLAGS}")
...
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <Halide.h>
#include <HalideBuffer.h>
#include <functional>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
using namespace std;
using namespace ngraph;
#define TI(x) type_index(typeid(x))
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace halide
{
static const std::unordered_map<std::type_index,
std::function<Halide::Func(vector<Halide::Func>)>>
generators{{TI(ngraph::op::Add),
[](vector<Halide::Func> in) {
Halide::Var x;
Halide::Func func;
func(x) = in[0](x) + in[1](x);
return func;
}},
{TI(ngraph::op::Multiply),
[](vector<Halide::Func> in) {
Halide::Var x;
Halide::Func func;
func(x) = in[0](x) * in[1](x);
return func;
}},
{TI(ngraph::op::Relu), [](vector<Halide::Func> in) {
Halide::Var x;
Halide::Func func;
func(x) = Halide::max(in[0](x), 0);
return func;
}}};
}
template <>
void Builder::BUILDER_DECL(ngraph::runtime::cpu::op::HalideOp)
{
const ngraph::runtime::cpu::op::HalideOp* hs =
static_cast<const ngraph::runtime::cpu::op::HalideOp*>(node);
auto& halide_functions = external_function->get_halide_functions();
auto& subgraph_params = external_function->get_subgraph_params();
auto& subgraph_param_sizes = external_function->get_subgraph_param_sizes();
auto& subgraph_param_ptrs = external_function->get_subgraph_param_ptrs();
for (const auto& op : hs->get_ops())
{
if (!halide::generators.count(TI(*op)))
{
throw ngraph_error("Invalid op in halide subgraph");
}
vector<Halide::Func> inputs;
for (const auto& input : op->get_inputs())
{
auto tensor_name = input.get_output().get_tensor_ptr()->get_name();
if (halide_functions.count(tensor_name))
{
inputs.emplace_back(halide_functions[tensor_name]);
}
else
{
subgraph_params[tensor_name] = Halide::ImageParam(Halide::Float(32), 1);
subgraph_param_sizes[tensor_name] =
shape_size(input.get_output().get_tensor_ptr()->get_shape());
subgraph_param_ptrs.emplace(
tensor_name, external_function->get_tensor_data(tensor_name));
inputs.emplace_back(subgraph_params[tensor_name]);
}
}
halide_functions[op->get_output_tensor_ptr()->get_name()] =
halide::generators.at(TI(*op))(inputs);
}
auto out_tensor_name = hs->get_ops().back()->get_output_tensor_ptr()->get_name();
auto& functors = external_function->get_functors();
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto& terminal_func = halide_functions[out_tensor_name];
auto out_size = out[0].get_size();
auto functor = [&, out_size](CPURuntimeContext* ctx) {
for (auto& param : subgraph_params)
{
Halide::Buffer<float> param_buffer(
static_cast<float*>(subgraph_param_ptrs.at(param.first).get()),
subgraph_param_sizes.at(param.first));
param.second.set(param_buffer);
}
Halide::Buffer<float> out_buffer(static_cast<float*>(out_tensor), out_size);
terminal_func.realize(out_buffer);
};
functors.emplace_back(functor);
}
}
}
}
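For context, a minimal standalone Halide program (an illustration only, not part of this commit) showing the structure the builder assembles: each generator turns an op into a Func defined over its inputs' Funcs, and only the terminal Func is realized, so the whole chain compiles to one fused kernel:

#include <Halide.h>
#include <HalideBuffer.h>

int main()
{
    Halide::ImageParam a(Halide::Float(32), 1);
    Halide::ImageParam b(Halide::Float(32), 1);
    Halide::Var x;

    Halide::Func sum;
    Halide::Func relu;
    sum(x) = a(x) + b(x);                // what the op::Add generator builds
    relu(x) = Halide::max(sum(x), 0.0f); // op::Relu consumes Add's Func

    Halide::Buffer<float> in0(8), in1(8), out(8);
    for (int i = 0; i < 8; i++)
    {
        in0(i) = -4.0f;
        in1(i) = static_cast<float>(i);
    }
    a.set(in0);
    b.set(in1);

    relu.realize(out); // out(i) == max(i - 4, 0); Add and Relu run as one kernel
    return 0;
}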
@@ -98,6 +98,7 @@
#include "ngraph/runtime/cpu/kernel/tan.hpp"
#include "ngraph/runtime/cpu/kernel/tanh.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
+#include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
@@ -367,7 +368,9 @@ namespace ngraph
                static BuildOpMap build_dispatcher{
                    {TI(ngraph::op::Parameter), &runtime::cpu::Builder::nop},
                    {TI(ngraph::runtime::cpu::op::ConvertLayout),
-                     &runtime::cpu::Builder::build<ngraph::runtime::cpu::op::ConvertLayout>}};
+                     &runtime::cpu::Builder::build<ngraph::runtime::cpu::op::ConvertLayout>},
+                    {TI(ngraph::runtime::cpu::op::HalideOp),
+                     &runtime::cpu::Builder::build<ngraph::runtime::cpu::op::HalideOp>}};

                return build_dispatcher;
            }
...
@@ -170,6 +170,7 @@
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
+#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"

#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/op/allreduce.hpp"
@@ -1023,6 +1024,10 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
    // pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>();
    pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
+#if defined(NGRAPH_HALIDE)
+    pass_manager.register_pass<ngraph::runtime::cpu::pass::HalideSubgraphExtraction>();
+#endif
    NodeVector nv_cwi; // We dont need CPUWorkspaceInsertion to return list of indices
    pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi, false);
    pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
...
@@ -27,6 +27,10 @@
#include <utility>
#include <vector>

+#if defined(NGRAPH_HALIDE)
+#include <Halide.h>
+#endif
+
#if !defined(NGRAPH_DEX_ONLY)

#include "ngraph/codegen/code_writer.hpp"
@@ -135,6 +139,26 @@ namespace ngraph
                                         const std::string& directory,
                                         const std::string& filename);

+#if defined(NGRAPH_HALIDE)
+            std::unordered_map<std::string, Halide::Func>& get_halide_functions()
+            {
+                return halide_functions;
+            }
+            std::unordered_map<std::string, Halide::ImageParam>& get_subgraph_params()
+            {
+                return subgraph_params;
+            }
+            std::unordered_map<std::string, int>& get_subgraph_param_sizes()
+            {
+                return subgraph_param_sizes;
+            }
+            std::unordered_map<std::string, std::reference_wrapper<void*>>&
+                get_subgraph_param_ptrs()
+            {
+                return subgraph_param_ptrs;
+            }
+#endif
+
        protected:
            void build();
@@ -240,6 +264,13 @@ namespace ngraph
            std::unordered_map<std::string, std::shared_ptr<CPU_ExternalFunction>> callees;
            bool m_is_built;
            bool m_direct_execution;
+
+#if defined(NGRAPH_HALIDE)
+            std::unordered_map<std::string, Halide::Func> halide_functions;
+            std::unordered_map<std::string, Halide::ImageParam> subgraph_params;
+            std::unordered_map<std::string, int> subgraph_param_sizes;
+            std::unordered_map<std::string, std::reference_wrapper<void*>> subgraph_param_ptrs;
+#endif
        };
    }
}
...
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/op/halide_op.hpp"
using namespace std;
using namespace ngraph;
shared_ptr<Node> runtime::cpu::op::HalideOp::copy_with_new_args(const NodeVector& new_args) const
{
return make_shared<HalideOp>(new_args, ops, output_type, output_shape);
}
runtime::cpu::op::HalideOp::HalideOp(const NodeVector& args,
const std::list<std::shared_ptr<Node>>& ops,
const element::Type& out_type,
const Shape& out_shape)
: Op("HalideOp", check_single_output_args(args))
, ops(ops)
, output_type(out_type)
, output_shape(out_shape)
{
constructor_validate_and_infer_types();
}
void runtime::cpu::op::HalideOp::validate_and_infer_types()
{
set_output_type(0, output_type, output_shape);
}
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <list>
#include <vector>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace op
{
class HalideOp : public ngraph::op::Op
{
public:
HalideOp(const NodeVector& args,
const std::list<std::shared_ptr<Node>>& ops,
const element::Type& out_type,
const Shape& out_shape);
virtual void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const std::list<std::shared_ptr<Node>>& get_ops() const { return ops; }
private:
std::list<std::shared_ptr<Node>> ops;
element::Type output_type;
Shape output_shape;
};
}
}
}
}
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <iostream>
#include <list>
#include <typeindex>
#include <typeinfo>
#include <unordered_set>
#include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
using namespace std;
using namespace ngraph;
#define TI(x) type_index(typeid(x))
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace halide
{
static const std::unordered_set<std::type_index> whitelist{
TI(ngraph::op::Add), TI(ngraph::op::Multiply), TI(ngraph::op::Relu)};
static const std::unordered_set<std::type_index> skiplist{TI(ngraph::op::Parameter),
TI(ngraph::op::Result)};
}
}
}
}
// Support for multiple results, multiple outputs and getoutputelement, and multiple subgraphs in a single
// pipeline is not implemented since this should go away in favor of the "hybrid" transformer approach of
// carving out subgraphs in core ngraph
bool runtime::cpu::pass::HalideSubgraphExtraction::run_on_function(
std::shared_ptr<ngraph::Function> function)
{
list<shared_ptr<Node>> worklist;
auto results = function->get_results();
// Artificial limitation
if (results.size() > 1)
{
return false;
}
if (function->get_result()->get_element_type() != element::f32)
{
return false;
}
for (const auto& result : results)
{
worklist.emplace_back(result);
}
unordered_set<shared_ptr<Node>> ops;
list<shared_ptr<Node>> ordered_ops;
while (!worklist.empty())
{
const auto& node = worklist.front();
if (!halide::skiplist.count(TI(*node)))
{
if (halide::whitelist.count(TI(*node)))
{
ops.emplace(node);
ordered_ops.emplace_back(node);
}
else
{
break;
}
}
const auto& args = node->get_arguments();
for (const auto& arg : args)
{
worklist.emplace_back(arg);
}
worklist.pop_front();
}
NodeVector liveins;
for (const auto& op : ops)
{
const auto& args = op->get_arguments();
for (const auto& arg : args)
{
if (!ops.count(arg))
{
liveins.emplace_back(arg);
}
}
}
ordered_ops.reverse();
auto subgraph = make_shared<cpu::op::HalideOp>(liveins,
ordered_ops,
function->get_result()->get_element_type(),
function->get_result()->get_shape());
replace_node(function->get_result()->get_argument(0), subgraph);
return true;
}
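The live-in computation above picks out exactly the values that cross into the extracted subgraph: any argument of a collected op that is not itself in the collected set. A toy model of that rule alone (hypothetical Node type, not from this commit):

#include <cassert>
#include <memory>
#include <set>
#include <vector>

struct Node
{
    std::vector<std::shared_ptr<Node>> args;
};

// An input is a live-in iff it is produced outside the collected op set.
std::vector<std::shared_ptr<Node>> liveins(const std::set<std::shared_ptr<Node>>& ops)
{
    std::vector<std::shared_ptr<Node>> result;
    for (const auto& op : ops)
    {
        for (const auto& arg : op->args)
        {
            if (!ops.count(arg))
            {
                result.push_back(arg);
            }
        }
    }
    return result;
}

int main()
{
    auto a = std::make_shared<Node>();    // produced outside the subgraph
    auto add = std::make_shared<Node>();  // collected
    add->args = {a};
    auto relu = std::make_shared<Node>(); // collected
    relu->args = {add};
    assert(liveins({add, relu}).size() == 1); // only 'a' crosses the boundary
    return 0;
}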
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class HalideSubgraphExtraction : public ngraph::pass::FunctionPass
{
public:
HalideSubgraphExtraction() {}
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
}
}
}
@@ -68,6 +68,9 @@ endif()
if (NGRAPH_CPU_ENABLE)
    list(APPEND SRC core_fusion.cpp quantize_cpu.cpp)
    list(APPEND SRC backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_reshape_sinking.cpp)
+    if (NGRAPH_HALIDE)
+        list(APPEND SRC halide.cpp)
+    endif()
    set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} CPU)
endif()
...
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(halide, halide_subgraph)
{
Shape shape{8};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto D = make_shared<op::Parameter>(element::f32, shape);
auto relu = make_shared<op::Relu>((A + B) * C);
auto f = make_shared<Function>(relu + D, op::ParameterVector{A, B, C, D});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> c = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> d = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape);
vector<float> data{-1, 4, -2, 5, 1, 5, 7, 9};
copy_data(a, data);
copy_data(b, data);
copy_data(c, data);
copy_data(d, data);
vector<float> expected{1, 36, 6, 55, 3, 55, 105, 171};
backend->call_with_validate(f, {result}, {a, b, c, d});
EXPECT_TRUE(test::all_close(read_vector<float>(result), expected, 1.0e-4f, 1.0e-4f));
}
@@ -25,12 +25,8 @@ using namespace ngraph;
#define EXPECT_HAS_SUBSTRING(haystack, needle) \
    EXPECT_PRED_FORMAT2(testing::IsSubstring, needle, haystack)

-//
-// Tests for broadcast.
-//
TEST(type_prop, broadcast_deduce)
{
-    // Deduce type
    auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
    Shape bc_shape{2, 3, 4};
    auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
@@ -38,6 +34,175 @@ TEST(type_prop, broadcast_deduce)
    ASSERT_EQ(bc->get_shape(), bc_shape);
}
TEST(type_prop, broadcast_axes_oob)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = Shape{2, 3, 4};
try
{
auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1, 3});
FAIL() << "Broadcast axis out of bounds not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Broadcast axis index (3) exceeds specified output shape rank");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_shape_mismatch_wrong_rank)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = Shape{2, 3, 4, 5};
try
{
auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
FAIL() << "Output shape mismatch (wrong rank) not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Broadcast argument shape, specified output shape, and axes are incompatible");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_shape_mismatch_wrong_size)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bc_shape = Shape{2, 3, 5};
try
{
auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
FAIL() << "Output shape mismatch (wrong size) not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Broadcast argument shape, specified output shape, and axes are incompatible");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_partial_rank_dynamic_ok)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
Shape bc_shape{2, 3, 4};
auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
ASSERT_EQ(bc->get_element_type(), element::f32);
ASSERT_EQ(bc->get_shape(), bc_shape);
}
TEST(type_prop, broadcast_partial_rank_dynamic_axes_oob)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto bc_shape = Shape{2, 3, 4};
try
{
auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1, 3});
FAIL() << "Broadcast axis out of bounds not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Broadcast axis index (3) exceeds specified output shape rank");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_partial_rank_static_dynamic_ok)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
Shape bc_shape{2, 3, 4};
auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
ASSERT_EQ(bc->get_element_type(), element::f32);
ASSERT_EQ(bc->get_shape(), bc_shape);
}
TEST(type_prop, broadcast_partial_rank_static_dynamic_axes_oob)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
auto bc_shape = Shape{2, 3, 4};
try
{
auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1, 3});
FAIL() << "Broadcast axis out of bounds not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Broadcast axis index (3) exceeds specified output shape rank");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_rank)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
auto bc_shape = Shape{2, 3, 4, 5};
try
{
auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
FAIL() << "Output shape mismatch (wrong rank) not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Broadcast argument shape, specified output shape, and axes are incompatible");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_size)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
auto bc_shape = Shape{2, 3, 5};
try
{
auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
FAIL() << "Output shape mismatch (wrong size) not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Broadcast argument shape, specified output shape, and axes are incompatible");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batchnorm_rank_less_than_2)
{
    auto dummy = make_shared<op::Parameter>(element::f32, Shape{1});
@@ -949,7 +1114,7 @@ TEST(type_prop, select_shape_mismatch_a)
    }
    catch (const NodeValidationError& error)
    {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not all have the same shape"));
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
    }
    catch (...)
    {
@@ -970,7 +1135,7 @@ TEST(type_prop, select_shape_mismatch_b)
    }
    catch (const NodeValidationError& error)
    {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not all have the same shape"));
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
    }
    catch (...)
    {
@@ -991,7 +1156,7 @@ TEST(type_prop, select_shape_mismatch_c)
    }
    catch (const NodeValidationError& error)
    {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not all have the same shape"));
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
    }
    catch (...)
    {
@@ -1035,7 +1200,160 @@ TEST(type_prop, select_elem_mismatch_bc)
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
-                             std::string("Arguments 1 and 2 do not have the same element type"));
+                             std::string("Argument 1 and 2 element types are inconsistent"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, select_partial_all_rank_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, select_partial_all_rank_dynamic_arg0_et_dynamic_arg1_arg2_et_mismatch)
{
auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
try
{
auto sel = make_shared<op::Select>(param0, param1, param2);
FAIL() << "Did not detect mismatched element types for args 1 and 2 (element type-dynamic "
"arg0)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument 1 and 2 element types are inconsistent"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg1_et_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg2_et_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg1_arg2_et_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::dynamic);
ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, select_partial_arg0_rank_dynamic_static_arg1_arg2_rank_dynamic_ok)
{
auto param0 =
make_shared<op::Parameter>(element::boolean, PartialShape{2, Dimension::dynamic(), 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(
sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
}
TEST(type_prop, select_partial_arg1_rank_dynamic_static_arg0_arg2_rank_dynamic_ok)
{
auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
auto param1 =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(
sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
}
TEST(type_prop, select_partial_arg2_rank_dynamic_static_arg0_arg1_rank_dynamic_ok)
{
auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(
sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
}
TEST(type_prop, select_partial_all_rank_static_dynamic_ok)
{
auto param0 = make_shared<op::Parameter>(
element::boolean, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
auto param1 = make_shared<op::Parameter>(
element::f32, PartialShape{Dimension::dynamic(), 8, Dimension::dynamic()});
auto param2 = make_shared<op::Parameter>(
element::f32, PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3});
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(sel->get_output_partial_shape(0).is_static());
ASSERT_EQ(sel->get_output_shape(0), (Shape{2, 8, 3}));
}
TEST(type_prop, select_partial_all_rank_static_intransitive_incompatibility)
{
auto param0 = make_shared<op::Parameter>(
element::boolean, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
auto param1 = make_shared<op::Parameter>(
element::f32, PartialShape{Dimension::dynamic(), 8, Dimension::dynamic()});
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{3, Dimension::dynamic(), 3});
try
{
auto sel = make_shared<op::Select>(param0, param1, param2);
FAIL() << "Did not detect intransitive partial-shape incompatibility";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
    }
    catch (...)
    {
@@ -1654,7 +1972,9 @@ TEST(type_prop, slice_deduce_vector_invalid_strides)
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
-            error.what(), std::string("Rank of strides (2) does not match rank of argument (1)"));
+            error.what(),
+            std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
+                        "(Coordinate{7}) and strides (Strides{1, 2}) do not match"));
    }
    catch (...)
    {
@@ -1757,7 +2077,8 @@ TEST(type_prop, slice_deduce_matrix_lower_missing)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
-            std::string("Rank of lower bounds (1) does not match rank of argument (2)"));
+            std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
+                        "(Coordinate{5, 5}) and strides (Strides{1}) do not match"));
    }
    catch (...)
    {
@@ -1778,7 +2099,8 @@ TEST(type_prop, slice_deduce_matrix_upper_missing)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
-            std::string("Rank of upper bounds (1) does not match rank of argument (2)"));
+            std::string("Ranks of lower bounds (Coordinate{0, 0}), upper bounds "
+                        "(Coordinate{5}) and strides (Strides{1, 1}) do not match"));
    }
    catch (...)
    {
@@ -1797,9 +2119,10 @@ TEST(type_prop, slice_deduce_matrix_lower_extra)
    }
    catch (const NodeValidationError& error)
    {
-        EXPECT_HAS_SUBSTRING(
-            error.what(),
-            std::string("Rank of lower bounds (3) does not match rank of argument (2)"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Ranks of lower bounds (Coordinate{0, 0, "
+                                         "0}), upper bounds (Coordinate{5, 5}) and "
+                                         "strides (Strides{1, 1, 1}) do not match"));
    }
    catch (...)
    {
@@ -1817,10 +2140,169 @@ TEST(type_prop, slice_deduce_matrix_upper_extra)
        FAIL() << "Extra upper bound coordinate not detected";
    }
    catch (const NodeValidationError& error)
    {
-        EXPECT_HAS_SUBSTRING(
-            error.what(),
-            std::string("Rank of upper bounds (3) does not match rank of argument (2)"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Ranks of lower bounds (Coordinate{0, 0}), "
+                                         "upper bounds (Coordinate{5, 5, 5}) and "
+                                         "strides (Strides{1, 1}) do not match"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, slice_partial_arg_input_rank_dynamic_attribs_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
ASSERT_EQ(sl->get_element_type(), element::f32);
ASSERT_EQ(sl->get_shape(), (Shape{0, 1, 2, 2}));
}
TEST(type_prop, slice_partial_arg_rank_dynamic_attribs_rank_mismatch)
{
PartialShape input_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of lower-bounds/upper-bounds/strides ranks not detected (argument "
"rank-dynamic)";
}
catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Ranks of lower bounds (Coordinate{1, 2, 3, 4}), upper bounds "
                        "(Coordinate{1, 3, 5}) and strides (Strides{1, 1, 1, 2}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_partial_arg_rank_dynamic_attribs_bounds_crossing)
{
PartialShape input_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 8};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Crossing lower/upper bounds not detected (argument rank-dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Lower bound for slice is greater than upper bound at axis 3 (lower "
"bounds: Coordinate{1, 2, 3, 8}, upper bounds: Coordinate{1, 3, 5, 7})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_ok)
{
PartialShape input_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
ASSERT_EQ(sl->get_element_type(), element::f32);
ASSERT_EQ(sl->get_shape(), (Shape{0, 1, 2, 2}));
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_some_dims_known_ok)
{
PartialShape input_shape{2, 4, 10, Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
ASSERT_EQ(sl->get_element_type(), element::f32);
ASSERT_EQ(sl->get_shape(), (Shape{0, 1, 2, 2}));
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_attribs_rank_mismatches_arg)
{
PartialShape input_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of attrib ranks with arg ranks not detected (argument rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input rank does not match the "
"rank of the lower bounds (Coordinate{1, 2, "
"3, 4}), upper bounds (Coordinate{1, 3, 5, "
"7}), and strides (Strides{1, 1, 1, 2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
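
// Worth noting: even though every dimension above is dynamic, the rank itself is
// static (5), so the mismatch against the rank-4 attribute vectors is detectable
// without any concrete dimension values.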
TEST(type_prop, slice_partial_arg_rank_static_dynamic_some_dims_known_upper_bounds_oob)
{
PartialShape input_shape{2, 2, 10, Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Upper bounds out of bounds not detected (argument rank-static dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Upper bound for slice at axis 1 is out of "
"range (upper bounds: Coordinate{1, 3, 5, "
"7}, argument shape: {2,2,10,?})"));
}
catch (...)
{
...@@ -1964,7 +2446,9 @@ TEST(type_prop, replace_slice_deduce_vector_invalid_strides)
catch (const NodeValidationError& error)
{
    EXPECT_HAS_SUBSTRING(
        error.what(),
        std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
                    "(Coordinate{7}) and strides (Strides{1, 2}) do not match"));
}
catch (...)
{
...@@ -2027,9 +2511,10 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch)
}
catch (const NodeValidationError& error)
{
    EXPECT_HAS_SUBSTRING(
        error.what(),
        std::string(
            "Shape of replacement tensor ({3,6}) does not match the slice shape ({4,6})"));
}
catch (...)
{
...@@ -2053,7 +2538,7 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch_strided)
    EXPECT_HAS_SUBSTRING(
        error.what(),
        std::string(
            "Shape of replacement tensor ({4,6}) does not match the slice shape ({4,3})"));
}
catch (...)
{
...@@ -2163,7 +2648,8 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_missing)
{
    EXPECT_HAS_SUBSTRING(
        error.what(),
        std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
                    "(Coordinate{5, 5}) and strides (Strides{1}) do not match"));
}
catch (...)
{
...@@ -2185,7 +2671,8 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_missing)
{
    EXPECT_HAS_SUBSTRING(
        error.what(),
        std::string("Ranks of lower bounds (Coordinate{0, 0}), upper bounds "
                    "(Coordinate{5}) and strides (Strides{1, 1}) do not match"));
}
catch (...)
{
...@@ -2206,9 +2693,10 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_extra)
}
catch (const NodeValidationError& error)
{
    EXPECT_HAS_SUBSTRING(error.what(),
                         std::string("Ranks of lower bounds (Coordinate{0, 0, "
                                     "0}), upper bounds (Coordinate{5, 5}) and "
                                     "strides (Strides{1, 1, 1}) do not match"));
}
catch (...)
{
...@@ -2228,10 +2716,337 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_extra)
    FAIL() << "Extra upper bound coordinate not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Ranks of lower bounds (Coordinate{0, 0}), "
"upper bounds (Coordinate{5, 5, 5}) and "
"strides (Strides{1, 1}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
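
// Note on the assertion above: with both arguments rank-dynamic, the attribute
// rank (4) is the only rank information available, so the deduced output is a
// rank-4 shape with every dimension left dynamic.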
TEST(type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_rank_mismatch)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of lower-bounds/upper-bounds/strides ranks not detected (argument "
"rank-dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Ranks of lower bounds (Coordinate{1, 2, 3, 4}), upper bounds "
"(Coordinate{1, 3, 5}) and strides (Strides{1, 1, 1, 2}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_bounds_crossing)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 8};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Crossing lower/upper bounds not detected (argument rank-dynamic)";
}
catch (const NodeValidationError& error)
{
    EXPECT_HAS_SUBSTRING(
        error.what(),
        std::string("Lower bound for slice is greater than upper bound at axis 3 (lower "
                    "bounds: Coordinate{1, 2, 3, 8}, upper bounds: Coordinate{1, 3, 5, 7})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, replace_slice_partial_input_rank_static_dynamic_replacement_rank_dynamic_ok)
{
PartialShape input_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop,
replace_slice_partial_input_rank_static_dynamic_some_dims_known_replacement_rank_dynamic_ok)
{
PartialShape input_shape{2, 4, 10, Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(
rsl->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 10, Dimension::dynamic()}));
}
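
// ReplaceSlice writes values into a region of arg0 without changing its shape, so
// the deduced output simply mirrors the input: {2, 4, 10, ?}. The known dims also
// bound-check the attributes (1 <= 2, 3 <= 4, 5 <= 10), exactly as for op::Slice.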
TEST(
type_prop,
replace_slice_partial_input_rank_static_dynamic_replacement_rank_dynamic_attribs_rank_mismatches_input)
{
PartialShape input_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of attrib ranks with arg ranks not detected (argument rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument ranks do not match the rank of the lower bounds "
"(Coordinate{1, 2, 3, 4}), upper bounds (Coordinate{1, 3, "
"5, 7}), and strides (Strides{1, 1, 1, 2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(
type_prop,
replace_slice_partial_input_rank_static_dynamic_some_dims_known_replacement_rank_dynamic_upper_bounds_oob)
{
PartialShape input_shape{2, 2, 10, Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Upper bounds out of bounds not detected (argument rank-static dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Upper bound for slice at axis 1 is out of "
"range (upper bounds: Coordinate{1, 3, 5, "
"7}, argument shape: {2,2,10,?})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
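
// A hedged sketch of the per-axis bound check the test above exercises (an
// illustrative predicate, not the ngraph implementation): dynamic dims are
// skipped, static dims must accommodate the upper bound.
namespace
{
constexpr bool upper_bound_in_range(bool dim_is_static, size_t dim, size_t upper)
{
    return !dim_is_static || upper <= dim;
}
static_assert(upper_bound_in_range(true, 2, 1), "axis 0: 1 <= 2");
static_assert(!upper_bound_in_range(true, 2, 3), "axis 1: 3 > 2 triggers the error");
static_assert(upper_bound_in_range(false, 0, 7), "axis 3: dynamic, not checked");
}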
TEST(type_prop, replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_some_dims_known_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{0, Dimension::dynamic(), Dimension::dynamic(), 2};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(
type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_some_dims_known_attribs_mismatch_replacement_shape)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{1, Dimension::dynamic(), Dimension::dynamic(), 2};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of shape inferred from attributes with provided replacement shape not "
"detected (rank-dynamic/rank-static dynamic inputs)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shape of replacement tensor ({1,?,?,2}) does not match "
"the slice shape ({0,1,2,2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
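
// Worked numbers for the mismatch above: the attributes imply a slice shape of
// {0, 1, 2, 2} (same ceil((upper - lower) / stride) arithmetic as before), and the
// replacement shape {1, ?, ?, 2} is incompatible at axis 0 (1 vs. 0); the dynamic
// axes merge fine and axis 3 matches exactly.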
TEST(
type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_attribs_rank_mismatches_replacement)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of attrib ranks with arg ranks not detected (arguments "
"rank-dynamic/rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument ranks do not match the rank of the lower bounds "
"(Coordinate{1, 2, 3, 4}), upper bounds (Coordinate{1, 3, "
"5, 7}), and strides (Strides{1, 1, 1, 2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(
type_prop,
replace_slice_partial_input_rank_static_dynamic_replacement_rank_static_dynamic_argument_ranks_mismatch)
{
PartialShape input_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
PartialShape replacement_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatching input/replacement ranks not detected (arguments both rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument ranks do not match"));
}
catch (...)
{
...