Commit 0be581c0 authored by Chris Sullivan's avatar Chris Sullivan Committed by Robert Kimball

GPU Padding - add support for custom pad value and interior padding (#860)

* * cuda_emitter::build_pad now utilizes pad_value.

* Added TypeInfo class for dispatching c-type information from the underlying ngraph element::Type.
  Adjusted test to use all_close when comparing floating point values (max_pool_2d_1channel_1image_overpadded).

* Refactored max_pool_1d into cuda_emitter so that numeric_limits<c_type>::lowest() could be used for initial max value.
Test max_pool_2d_1channel_1image_padded_negative_values now enabled and passes.

* Removed old function and switch to size_t to match ngraph.

* Added virtual dtor.

* Adding support for interior padding. All op::Pad functionality is now included.

* More info in runtime_error for checking of tensor dimensions. Removed commented code.
parent 8cb48d37
...@@ -283,6 +283,7 @@ endif() ...@@ -283,6 +283,7 @@ endif()
runtime/gpu/gpu_invoke.cpp runtime/gpu/gpu_invoke.cpp
runtime/gpu/cudnn_emitter.cpp runtime/gpu/cudnn_emitter.cpp
runtime/gpu/cuda_emitter.cpp runtime/gpu/cuda_emitter.cpp
runtime/gpu/type_info.cpp
) )
set_property(SOURCE codegen/compiler.cpp APPEND_STRING PROPERTY COMPILE_DEFINITIONS set_property(SOURCE codegen/compiler.cpp APPEND_STRING PROPERTY COMPILE_DEFINITIONS
"CUDA_HEADER_PATHS=\"${CUDA_INCLUDE_DIRS}\";") "CUDA_HEADER_PATHS=\"${CUDA_INCLUDE_DIRS}\";")
......
This diff is collapsed.
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <array> #include <array>
#include "ngraph/codegen/code_writer.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -40,10 +41,21 @@ namespace ngraph ...@@ -40,10 +41,21 @@ namespace ngraph
const Shape& output_shape, const Shape& output_shape,
const Shape& pad_below, const Shape& pad_below,
const Shape& pad_above, const Shape& pad_above,
const Shape& pad_interior); const Shape& pad_interior,
const std::string& pad_value = "");
size_t build_1d_max_pool(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
const Shape& input_shape,
const Shape& output_shape,
size_t window_width,
size_t window_stride);
private: private:
CUDAEmitter(GPUPrimitiveEmitter* emitter); CUDAEmitter(GPUPrimitiveEmitter* emitter);
void print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name,
const Shape& shape);
GPUPrimitiveEmitter* m_primitive_emitter; GPUPrimitiveEmitter* m_primitive_emitter;
}; };
......
...@@ -167,48 +167,6 @@ void runtime::gpu::CudaKernelBuilder::get_slice_op(codegen::CodeWriter& writer, ...@@ -167,48 +167,6 @@ void runtime::gpu::CudaKernelBuilder::get_slice_op(codegen::CodeWriter& writer,
writer.block_end(); writer.block_end();
} }
void runtime::gpu::CudaKernelBuilder::get_1d_max_pool(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types)
{
// assumes data is in NCW format
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1] << "* out, "
<< "int width, "
<< "int stride, "
<< "int input_size, "
<< "int output_size, "
<< "int n)\n";
writer.block_begin();
{
// index into output tensor
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < n)\n";
writer.block_begin();
{
// index into input tensor
writer << "size_t start = (tid / output_size) * input_size + (tid % output_size) * "
"stride;\n";
writer << data_types[0] << " max_val = 0;\n";
writer << "for (size_t i = start; i < start+width; i++)\n";
writer.block_begin();
{
writer << "const " << data_types[0] << " input = in[i];\n";
writer << "if (input > max_val)";
writer.block_begin();
{
writer << "max_val = input;\n";
}
writer.block_end();
}
writer.block_end();
writer << "out[tid] = max_val;\n";
}
writer.block_end();
}
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_device_helper( void runtime::gpu::CudaKernelBuilder::get_device_helper(
codegen::CodeWriter& writer, codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
......
...@@ -55,10 +55,6 @@ namespace ngraph ...@@ -55,10 +55,6 @@ namespace ngraph
const std::string& name, const std::string& name,
const std::array<std::string, 2>& data_types); const std::array<std::string, 2>& data_types);
static void get_1d_max_pool(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types);
static void get_device_helper(codegen::CodeWriter& writer, static void get_device_helper(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
const std::string& math_kernel, const std::string& math_kernel,
......
...@@ -124,42 +124,6 @@ namespace ngraph ...@@ -124,42 +124,6 @@ namespace ngraph
0)); // arguments 0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output. CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
} }
template <typename... Args>
void emit_1d_max_pool(GPURuntimeContext* ctx,
const std::string& name,
const std::array<std::string, 2>& data_types,
size_t count,
Args&&... args)
{
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::get_1d_max_pool(writer, name_signature, data_types);
std::string kernel = writer.get_code();
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
if (sizeof...(args))
{
std::vector<void*> args_list = {&args..., &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
&args_list[0],
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
}
} }
} }
} }
...@@ -96,6 +96,7 @@ ...@@ -96,6 +96,7 @@
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp" #include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp" #include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp" #include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/runtime/gpu/type_info.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -1232,9 +1233,10 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -1232,9 +1233,10 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
output_shape, output_shape,
padding_below, padding_below,
padding_above, padding_above,
{}); padding_interior);
writer << "gpu::invoke_primitive(ctx, " << pad_index << ", "; writer << "gpu::invoke_primitive(ctx, " << pad_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), "; writer << "std::vector<void*>{" << args[0].get_name() << ", "
<< args[1].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data() "; writer << "std::vector<void*>{" << out[0].get_name() << "}.data() ";
writer << ");\n"; writer << ");\n";
} }
...@@ -1252,10 +1254,12 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -1252,10 +1254,12 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
auto& result_shape = out[0].get_shape(); auto& result_shape = out[0].get_shape();
auto padding_below = max_pool->get_padding_below(); auto padding_below = max_pool->get_padding_below();
auto padding_above = max_pool->get_padding_above(); auto padding_above = max_pool->get_padding_above();
if (padding_below.size() != padding_above.size()) if (input_shape.size() < 3)
{ {
throw std::runtime_error( throw std::runtime_error(
"Padding below and above are of different dimension."); "MaxPool operation requested for a tensor of less than 3 dimensions. "
"Tensors should have at least one spatial dimension, dim(NC{d1...dN}) "
"<= 3");
} }
bool pad_required = false; bool pad_required = false;
...@@ -1276,8 +1280,9 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -1276,8 +1280,9 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
shape_size(shape_to_pool) * args[0].get_element_type().size(); shape_size(shape_to_pool) * args[0].get_element_type().size();
writer << "void* pad_buffer = " writer << "void* pad_buffer = "
<< "runtime::gpu::create_gpu_buffer(" << temp_size << ");\n"; << "runtime::gpu::create_gpu_buffer(" << temp_size << ");\n";
writer << "runtime::gpu::cuda_memset(pad_buffer, 0, " << temp_size
<< ");\n"; std::stringstream ss;
ss << TypeInfo::Get(args[0].get_element_type())->lowest();
auto pad_index = auto pad_index =
cuda_emitter->build_pad(external_function->ctx().get(), cuda_emitter->build_pad(external_function->ctx().get(),
...@@ -1286,7 +1291,8 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -1286,7 +1291,8 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
shape_to_pool, shape_to_pool,
padding_below, padding_below,
padding_above, padding_above,
/*padding_interior*/ {}); /*padding_interior*/ {},
ss.str());
writer << "gpu::invoke_primitive(ctx, " << pad_index << ", "; writer << "gpu::invoke_primitive(ctx, " << pad_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), "; writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
...@@ -1308,40 +1314,40 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -1308,40 +1314,40 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
} }
} }
// 1d max pool if (input_shape.size() <= 5)
if (input_shape.size() == 3 || num_nontrivial_dims == 1)
{
// pre-compile cuda kernel
runtime::gpu::emit_1d_max_pool(external_function->ctx().get(),
max_pool->description(),
{{args[0].get_type(), out[0].get_type()}},
0);
// emit invocation of kernel
writer << "runtime::gpu::emit_1d_max_pool("
<< "ctx, "
<< "\"" << max_pool->description() << "\", "
<< "{\"" << args[0].get_type() << "\", \"" << out[0].get_type()
<< "\"}, " << out[0].get_size() << ", " << args[0].get_name() << ", "
<< out[0].get_name() << ", " << max_pool->get_window_shape()[0]
<< ", " << max_pool->get_window_movement_strides()[0] << ", "
<< input_shape.back() << ", " << result_shape.back() << ");\n";
}
// 2d and 3d max pool (NCHW)
else if (input_shape.size() == 4 || input_shape.size() == 5)
{ {
auto& cudnn_emitter = size_t max_pool_index = 0;
external_function->get_primitive_emitter()->get_cudnn_emitter(); // 1d max pool (NCW)
if ((input_shape.size() == 3 || num_nontrivial_dims == 1))
{
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
max_pool_index = cuda_emitter->build_1d_max_pool(
external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
input_shape,
result_shape,
max_pool->get_window_shape().back(),
max_pool->get_window_movement_strides().back());
}
// 2d and 3d max pool (NCHW)
else if (input_shape.size() == 4 || input_shape.size() == 5)
{
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto max_pool_index = max_pool_index = cudnn_emitter->build_pooling(
cudnn_emitter->build_pooling(external_function->ctx().get(), external_function->ctx().get(),
CUDNN_POOLING_MAX, CUDNN_POOLING_MAX,
CUDNNEmitter::Prop::Forward, CUDNNEmitter::Prop::Forward,
shape_to_pool, shape_to_pool,
result_shape, result_shape,
max_pool->get_window_movement_strides(), max_pool->get_window_movement_strides(),
max_pool->get_window_shape(), max_pool->get_window_shape(),
padding_below, padding_below,
padding_above); padding_above);
}
writer << "gpu::invoke_primitive(ctx, " << max_pool_index << ", "; writer << "gpu::invoke_primitive(ctx, " << max_pool_index << ", ";
if (pad_required) if (pad_required)
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/gpu/type_info.hpp"
using namespace ngraph;
const runtime::gpu::TypeInfo::TypeDispatch runtime::gpu::TypeInfo::dispatcher{
{"char", std::make_shared<runtime::gpu::TypeInfo_Impl<char>>()},
{"float", std::make_shared<runtime::gpu::TypeInfo_Impl<float>>()},
{"double", std::make_shared<runtime::gpu::TypeInfo_Impl<double>>()},
{"int8_t", std::make_shared<runtime::gpu::TypeInfo_Impl<int8_t>>()},
{"int16_t", std::make_shared<runtime::gpu::TypeInfo_Impl<int16_t>>()},
{"int32_t", std::make_shared<runtime::gpu::TypeInfo_Impl<int32_t>>()},
{"int64_t", std::make_shared<runtime::gpu::TypeInfo_Impl<int64_t>>()},
{"uint8_t", std::make_shared<runtime::gpu::TypeInfo_Impl<uint8_t>>()},
{"uint16_t", std::make_shared<runtime::gpu::TypeInfo_Impl<uint16_t>>()},
{"uint32_t", std::make_shared<runtime::gpu::TypeInfo_Impl<uint32_t>>()},
{"uint64_t", std::make_shared<runtime::gpu::TypeInfo_Impl<uint64_t>>()}};
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class TypeInfo
{
public:
virtual ~TypeInfo() = default;
// Helper functions to request information about the underlying c-type
// that is implicitly associated with the registed element::Type
virtual std::string lowest() const = 0;
virtual std::string min() const = 0;
virtual std::string max() const = 0;
using TypeDispatch = std::unordered_map<std::string, std::shared_ptr<TypeInfo>>;
static const std::shared_ptr<TypeInfo>& Get(const element::Type& type)
{
return dispatcher.at(type.c_type_string());
}
static const std::shared_ptr<TypeInfo>& Get(std::string type)
{
return dispatcher.at(type);
}
protected:
template <typename T>
std::string to_string(const T& val) const
{
std::stringstream ss;
ss.precision(std::numeric_limits<T>::digits10 + 2);
ss << val;
return ss.str();
}
private:
static const TypeDispatch dispatcher;
};
template <typename T>
class TypeInfo_Impl : public TypeInfo
{
public:
std::string lowest() const override
{
return to_string<T>(std::numeric_limits<T>::lowest());
}
std::string min() const override
{
return to_string<T>(std::numeric_limits<T>::min());
}
std::string max() const override
{
return to_string<T>(std::numeric_limits<T>::max());
}
};
}
}
}
...@@ -4200,7 +4200,6 @@ TEST(${BACKEND_NAME}, max_pool_2d_2channel_2image) ...@@ -4200,7 +4200,6 @@ TEST(${BACKEND_NAME}, max_pool_2d_2channel_2image)
TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_overpadded) TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_overpadded)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}"); SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
Shape shape_a{1, 1, 5, 5}; Shape shape_a{1, 1, 5, 5};
...@@ -4230,15 +4229,15 @@ TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_overpadded) ...@@ -4230,15 +4229,15 @@ TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_overpadded)
backend->call(f, {result}, {a}); backend->call(f, {result}, {a});
auto min = std::numeric_limits<float>::lowest(); auto min = std::numeric_limits<float>::lowest();
EXPECT_EQ((test::NDArray<float, 4>({{{{min, min, min, min, min}, EXPECT_TRUE(test::all_close(test::NDArray<float, 4>({{{{min, min, min, min, min},
{1, 2, 2, 2, 1}, {1, 2, 2, 2, 1},
{3, 3, 2, 2, 1}, {3, 3, 2, 2, 1},
{3, 3, 2, 1, 1}, {3, 3, 2, 1, 1},
{2, 1, 2, 2, 2}, {2, 1, 2, 2, 2},
{2, 2, 2, 2, 2}, {2, 2, 2, 2, 2},
{2, 2, 1, 0, 0}}}}) {2, 2, 1, 0, 0}}}})
.get_vector()), .get_vector(),
read_vector<float>(result)); read_vector<float>(result)));
} }
TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_padded) TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_padded)
...@@ -4286,7 +4285,6 @@ TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_padded) ...@@ -4286,7 +4285,6 @@ TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_padded)
// values still "win" versus out-of-bounds values), which is good. // values still "win" versus out-of-bounds values), which is good.
TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_padded_negative_values) TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_padded_negative_values)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}"); SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
auto shape_a = Shape{ auto shape_a = Shape{
...@@ -6207,7 +6205,6 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_3x3_strided_uneven) ...@@ -6207,7 +6205,6 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_3x3_strided_uneven)
TEST(${BACKEND_NAME}, pad_interior_1d) TEST(${BACKEND_NAME}, pad_interior_1d)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{6}; Shape shape_a{6};
auto A = make_shared<op::Parameter>(element::f32, shape_a); auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{}; Shape shape_b{};
...@@ -6238,7 +6235,6 @@ TEST(${BACKEND_NAME}, pad_interior_1d) ...@@ -6238,7 +6235,6 @@ TEST(${BACKEND_NAME}, pad_interior_1d)
TEST(${BACKEND_NAME}, pad_exterior_1d) TEST(${BACKEND_NAME}, pad_exterior_1d)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{6}; Shape shape_a{6};
auto A = make_shared<op::Parameter>(element::f32, shape_a); auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{}; Shape shape_b{};
...@@ -6269,7 +6265,6 @@ TEST(${BACKEND_NAME}, pad_exterior_1d) ...@@ -6269,7 +6265,6 @@ TEST(${BACKEND_NAME}, pad_exterior_1d)
TEST(${BACKEND_NAME}, pad_interior_exterior_1d) TEST(${BACKEND_NAME}, pad_interior_exterior_1d)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{6}; Shape shape_a{6};
auto A = make_shared<op::Parameter>(element::f32, shape_a); auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{}; Shape shape_b{};
...@@ -6301,7 +6296,6 @@ TEST(${BACKEND_NAME}, pad_interior_exterior_1d) ...@@ -6301,7 +6296,6 @@ TEST(${BACKEND_NAME}, pad_interior_exterior_1d)
TEST(${BACKEND_NAME}, pad_interior_exterior_2d) TEST(${BACKEND_NAME}, pad_interior_exterior_2d)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 3}; Shape shape_a{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_a); auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{}; Shape shape_b{};
...@@ -6337,7 +6331,6 @@ TEST(${BACKEND_NAME}, pad_interior_exterior_2d) ...@@ -6337,7 +6331,6 @@ TEST(${BACKEND_NAME}, pad_interior_exterior_2d)
TEST(${BACKEND_NAME}, pad_exterior_2d_0x0) TEST(${BACKEND_NAME}, pad_exterior_2d_0x0)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}"); SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
Shape shape_a{0, 0}; Shape shape_a{0, 0};
...@@ -6373,7 +6366,6 @@ TEST(${BACKEND_NAME}, pad_exterior_2d_0x0) ...@@ -6373,7 +6366,6 @@ TEST(${BACKEND_NAME}, pad_exterior_2d_0x0)
TEST(${BACKEND_NAME}, pad_exterior_2d_0x3) TEST(${BACKEND_NAME}, pad_exterior_2d_0x3)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}"); SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
Shape shape_a{0, 3}; Shape shape_a{0, 3};
...@@ -6409,7 +6401,6 @@ TEST(${BACKEND_NAME}, pad_exterior_2d_0x3) ...@@ -6409,7 +6401,6 @@ TEST(${BACKEND_NAME}, pad_exterior_2d_0x3)
TEST(${BACKEND_NAME}, pad_exterior_2d_3x0) TEST(${BACKEND_NAME}, pad_exterior_2d_3x0)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}"); SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
Shape shape_a{3, 0}; Shape shape_a{3, 0};
...@@ -6445,7 +6436,6 @@ TEST(${BACKEND_NAME}, pad_exterior_2d_3x0) ...@@ -6445,7 +6436,6 @@ TEST(${BACKEND_NAME}, pad_exterior_2d_3x0)
TEST(${BACKEND_NAME}, pad_exterior_4d_1x2x2x2) TEST(${BACKEND_NAME}, pad_exterior_4d_1x2x2x2)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}"); SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
Shape shape_a{1, 2, 2, 2}; Shape shape_a{1, 2, 2, 2};
...@@ -6515,7 +6505,6 @@ TEST(${BACKEND_NAME}, pad_exterior_4d_1x2x2x2) ...@@ -6515,7 +6505,6 @@ TEST(${BACKEND_NAME}, pad_exterior_4d_1x2x2x2)
// we should just count the pre-interior-padding length as zero. // we should just count the pre-interior-padding length as zero.
TEST(${BACKEND_NAME}, pad_interior_exterior_4d_2x0x3x2) TEST(${BACKEND_NAME}, pad_interior_exterior_4d_2x0x3x2)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}"); SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
Shape shape_a{2, 0, 3, 2}; Shape shape_a{2, 0, 3, 2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment