Unverified Commit 89da71d3 authored by Chris Sullivan, committed by GitHub

Adding support for GPU elementwise ops for arbitrarily many inputs (#618)

* Refactored unary elementwise ops into a single interface
that is adaptable to elementwise ops with an arbitrary number of inputs.

* Renamed EmitUnaryElementwise -> EmitElementwise.
Implemented the first binary elementwise op (Power).

* Refactored some of the boilerplate code for emitting CUDA kernels through NVRTC
out of the emit functions and into the CudaFunctionPool static singleton.
Generated CUDA kernels are now saved to ./gpu_codegen.

* Added the Divide, Subtract, and Sign ops to the GPU transformer.
Subtract and Sign both use custom device helper functions whose
math kernels are defined per op in gpu_cuda_kernel_ops.hpp and which
are emitted by the new get_device_helper function; a sketch of the
generated code follows below.
parent 7ab47c2e
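
For illustration only (not part of the diff): for an op such as Subtract, whose CudaOpMap entry sets op = "subtractf" and math_kernel = "x0-x1", the new builders emit NVRTC source roughly like the sketch below; the exact whitespace produced by CodeWriter differs slightly.

    // Sketch of the code emitted by CudaKernelBuilder::get_device_helper
    // for Subtract (math_kernel = "x0-x1", two inputs).
    __device__ float subtractf(float x0, float x1)
    {
        return x0 - x1;
    }

    // Sketch of the kernel emitted by CudaKernelBuilder::get_elementwise_op
    // for two inputs; one thread handles one output element.
    extern "C" __global__ void cuda_Subtract(float* in0, float* in1, float* out, size_t n)
    {
        size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
        if (tid < n)
        {
            out[tid] = subtractf(in0[tid], in1[tid]);
        }
    }

CudaFunctionPool::set compiles this source once, writes it to gpu_codegen/cuda_kernel_Subtract_codegen.cu, and caches the resulting CUfunction for subsequent launches.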
......@@ -14,11 +14,18 @@
* limitations under the License.
*******************************************************************************/
#include <cctype>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include "ngraph/file_util.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
static const std::string s_output_dir = "gpu_codegen";
namespace ngraph
{
namespace runtime
......@@ -31,12 +38,20 @@ namespace ngraph
return pool;
}
void CudaFunctionPool::set(std::string& name, std::shared_ptr<CUfunction> function)
void CudaFunctionPool::set(const std::string& name, const std::string& kernel)
{
m_function_map.insert({name, function});
const char* opts[] = {"--gpu-architecture=compute_35",
"--relocatable-device-code=true"};
std::string filename =
file_util::path_join(s_output_dir, "cuda_kernel_" + name + "_codegen.cu");
std::ofstream out(filename);
out << kernel;
out.close();
m_function_map.insert(
{name, CudaFunctionBuilder::get("cuda_" + name, kernel, 2, opts)});
}
std::shared_ptr<CUfunction> CudaFunctionPool::get(std::string& name)
std::shared_ptr<CUfunction> CudaFunctionPool::get(const std::string& name)
{
auto it = m_function_map.find(name);
if (it != m_function_map.end())
......
......@@ -36,8 +36,8 @@ namespace ngraph
CudaFunctionPool& operator=(CudaFunctionPool const&) = delete;
CudaFunctionPool& operator=(CudaFunctionPool&&) = delete;
void set(std::string& name, std::shared_ptr<CUfunction> function);
std::shared_ptr<CUfunction> get(std::string& name);
void set(const std::string& name, const std::string& kernel);
std::shared_ptr<CUfunction> get(const std::string& name);
protected:
CudaFunctionPool() {}
......
......@@ -14,6 +14,7 @@
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/codegen/code_writer.hpp"
namespace ngraph
{
......@@ -21,51 +22,66 @@ namespace ngraph
{
namespace gpu
{
void CudaKernelBuilder::get_unary_elementwise_op(const std::string& name,
const std::string& data_type,
const std::string& op,
std::string& kernel)
void CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& data_type,
const std::string& op,
const size_t& num_inputs)
{
kernel = R"(
extern "C" __global__
void cuda_)" + name + "(" +
data_type + "* in, " + data_type + "* out, size_t n)\n" + R"({
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
out[tid] =)" + op + "(in[tid]);\n" +
R"(}
})";
return;
}
writer << "extern \"C\" __global__ void cuda_" << name << "(";
for (size_t i = 0; i < num_inputs; i++)
{
writer << data_type << "* in" << i << ", ";
}
writer << data_type << "* out,"
<< "size_t n)\n";
writer << "{\n";
writer.indent++;
{
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x; \n";
writer << "if (tid < n)\n";
writer << "{\n";
writer.indent++;
{
writer << "out[tid] = " << op << "(";
for (size_t i = 0; i < num_inputs - 1; i++)
{
writer << "in" << i << "[tid], ";
}
writer << "in" << num_inputs - 1 << "[tid]);\n";
}
writer.indent--;
writer << "}\n";
}
writer.indent--;
writer << "}\n";
void CudaKernelBuilder::get_binary_elementwise_op(const std::string& name,
const std::string& data_type,
const std::string& op,
std::string& kernel)
{
kernel = R"(
extern "C" __global__
void )" + name + "(" + data_type +
"* in1, " + data_type + "* in2, " + data_type + "* out, size_t n)\n" +
R"({
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
out[tid] = in1[tid] )" + op +
"in2[tid]\n" +
R"(}
})";
return;
}
void
CudaKernelBuilder::get_arbitrary_elementwise_op(const std::string& name,
const std::string& data_type,
const std::vector<std::string>& ops,
std::string& kernel)
void CudaKernelBuilder::get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& data_type,
const std::string& math_kernel,
const size_t& num_inputs)
{
kernel = "";
if (math_kernel.size())
{
writer << "__device__ " << data_type << " " << name << "(";
for (size_t i = 0; i < num_inputs - 1; i++)
{
writer << data_type << " x" << i << ", ";
}
writer << data_type << " x" << num_inputs - 1;
writer << ")\n";
writer << "{\n";
writer.indent++;
{
writer << "return " + math_kernel << ";\n";
}
writer.indent--;
writer << "}\n";
}
return;
}
}
......
......@@ -21,6 +21,10 @@
namespace ngraph
{
namespace codegen
{
class CodeWriter;
}
namespace runtime
{
namespace gpu
......@@ -28,20 +32,17 @@ namespace ngraph
class CudaKernelBuilder
{
public:
static void get_unary_elementwise_op(const std::string& name,
const std::string& data_type,
const std::string& op,
std::string& kernel);
static void get_binary_elementwise_op(const std::string& name,
const std::string& data_type,
const std::string& op,
std::string& kernel);
static void get_elementwise_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& data_type,
const std::string& op,
const size_t& num_inputs);
static void get_arbitrary_elementwise_op(const std::string& name,
const std::string& data_type,
const std::vector<std::string>& ops,
std::string& kernel);
static void get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& data_type,
const std::string& math_kernel,
const size_t& num_inputs);
};
}
}
......
......@@ -33,8 +33,6 @@ namespace ngraph
// Create an instance of nvrtcProgram with the code string.
if (CudaFunctionPool::instance().get(name) == nullptr)
{
const char* opts[] = {"--gpu-architecture=compute_35",
"--relocatable-device-code=true"};
std::string kernel;
std::string data_type("float");
......@@ -50,9 +48,7 @@ void cuda_)" + name + "(" + data_type +
out[tid] = in[idx];
}
})";
CudaFunctionPool::instance().set(
name, CudaFunctionBuilder::get("cuda_" + name, kernel, 2, opts));
CudaFunctionPool::instance().set(name, kernel);
}
//convert runtime ptr to driver api ptr
......
......@@ -18,7 +18,6 @@
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/strides.hpp"
......@@ -35,27 +34,34 @@ namespace ngraph
void emit_broadcast(
void* in, void* out, size_t repeat_size, size_t repeat_times, size_t count);
template <typename T>
void emit_unary_elementwise_op(void* in, void* out, size_t count, std::string name)
template <typename T, typename... Inputs>
void emit_elementwise_op(std::string name,
size_t count,
CUdeviceptr out,
Inputs&&... inputs)
{
// Create an instance of nvrtcProgram with the code string.
if (CudaFunctionPool::instance().get(name) == nullptr)
{
const char* opts[] = {"--gpu-architecture=compute_35",
"--relocatable-device-code=true"};
std::string kernel;
CudaKernelBuilder::get_unary_elementwise_op(
name, "float", CudaOpMap<T>::op, kernel);
CudaFunctionPool::instance().set(
name, CudaFunctionBuilder::get("cuda_" + name, kernel, 2, opts));
codegen::CodeWriter writer;
if (CudaOpMap<T>::math_kernel)
{
CudaKernelBuilder::get_device_helper(writer,
CudaOpMap<T>::op,
CudaOpMap<T>::type,
CudaOpMap<T>::math_kernel,
sizeof...(inputs));
}
CudaKernelBuilder::get_elementwise_op(
writer, name, CudaOpMap<T>::type, CudaOpMap<T>::op, sizeof...(inputs));
std::string kernel = writer.get_code();
CudaFunctionPool::instance().set(name, kernel);
}
//convert runtime ptr to driver api ptr
CUdeviceptr d_ptr_in, d_ptr_out;
d_ptr_in = (CUdeviceptr)in;
d_ptr_out = (CUdeviceptr)out;
void* args_list[] = {&d_ptr_in, &d_ptr_out, &count};
void* args_list[] = {&inputs..., &out, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*CudaFunctionPool::instance().get(name).get(),
count,
1,
......
......@@ -34,13 +34,25 @@ namespace ngraph
class Sinh;
class Tan;
class Tanh;
class Power;
class Subtract;
class Divide;
class Sign;
// requires different input and output types
class Convert;
class Equal;
class NotEqual;
class Greater;
class GreaterEq;
class Less;
class LessEq;
// Unimplemented or unused in favor of cuDNN impl.
class Max;
class Min;
class Negative;
class Not;
class Sign;
class Sqrt;
}
namespace runtime
......@@ -51,102 +63,168 @@ namespace ngraph
struct CudaOpMap<ngraph::op::Abs>
{
static constexpr const char* op = "fabsf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Acos>
{
static constexpr const char* op = "acosf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Asin>
{
static constexpr const char* op = "asinf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Atan>
{
static constexpr const char* op = "atanf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Ceiling>
{
static constexpr const char* op = "ceilf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Cos>
{
static constexpr const char* op = "cosf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Cosh>
{
static constexpr const char* op = "coshf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Exp>
{
static constexpr const char* op = "expf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Floor>
{
static constexpr const char* op = "floorf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Log>
{
static constexpr const char* op = "logf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Max>
{
static constexpr const char* op = "fmaxf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Min>
{
static constexpr const char* op = "fminf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Sin>
{
static constexpr const char* op = "sinf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Sinh>
{
static constexpr const char* op = "sinhf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Sqrt>
{
static constexpr const char* op = "sqrtf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Tan>
{
static constexpr const char* op = "tanf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Tanh>
{
static constexpr const char* op = "tanhf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Power>
{
static constexpr const char* op = "powf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Subtract>
{
static constexpr const char* op = "subtractf";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = "x0-x1";
};
template <>
struct CudaOpMap<ngraph::op::Divide>
{
static constexpr const char* op = "fdividef";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = nullptr;
};
template <>
struct CudaOpMap<ngraph::op::Sign>
{
static constexpr const char* op = "sign";
static constexpr const char* type = "float";
static constexpr const char* math_kernel = "(x0 > 0) - (x0 < 0)";
};
}
}
......
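
A note on the op table above (illustrative, not part of the diff): when math_kernel is nullptr, as for Divide, the generated kernel calls the named CUDA intrinsic (here fdividef) directly; when math_kernel is a string, as for Sign, get_device_helper first emits a __device__ wrapper around that expression, roughly:

    // Sketch of the helper generated for Sign
    // (math_kernel = "(x0 > 0) - (x0 < 0)", a branch-free sign yielding -1, 0, or 1).
    __device__ float sign(float x0)
    {
        return (x0 > 0) - (x0 < 0);
    }

and the elementwise kernel then calls sign(in0[tid]).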
......@@ -105,39 +105,30 @@ namespace ngraph
{
namespace gpu
{
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Abs)
void runtime::gpu::GPU_Emitter::EmitElementwise(
GPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* n,
const vector<runtime::gpu::GPU_TensorViewWrapper>& args,
const vector<runtime::gpu::GPU_TensorViewWrapper>& out)
{
if (out[0].get_size() == 0)
{
return;
}
writer << "{ // " << node->get_name() << "\n";
writer << "{ // " << n->get_name() << "\n";
writer.indent++;
writer << "int count = " << out[0].get_size() << ";\n";
writer << "ngraph::runtime::gpu::emit_abs((void*) " << args[0].get_name()
<< ", (void*) " << out[0].get_name() << ", count);\n";
writer.indent--;
writer << "}\n";
}
void GPU_Emitter::EmitUnaryElementwise(GPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<GPU_TensorViewWrapper>& args,
const std::vector<GPU_TensorViewWrapper>& out)
{
if (out[0].get_size() == 0)
writer << "if(count == 0) return;\n";
writer << "ngraph::runtime::gpu::emit_elementwise_op<ngraph::op::"
<< n->description() << ">(\"" << n->description() << "\""
<< ", count"
<< ", (CUdeviceptr) " << out[0].get_name();
for (size_t i = 0; i < args.size(); i++)
{
return;
writer << ", (CUdeviceptr) " << args[i].get_name();
}
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer << "int count = " << out[0].get_size() << ";\n";
writer << "if(count == 0) return;\n";
writer << "ngraph::runtime::gpu::emit_unary_elementwise_op<ngraph::op::"
<< node->description() << ">((void*) " << args[0].get_name() << ", (void*) "
<< out[0].get_name() << ", count, \"" << node->description() << "\");\n";
writer << ");\n";
writer.indent--;
writer << "}\n";
}
......
......@@ -58,11 +58,11 @@ namespace ngraph
{
}
static void EmitUnaryElementwise(GPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<GPU_TensorViewWrapper>& args,
const std::vector<GPU_TensorViewWrapper>& out);
static void EmitElementwise(GPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<GPU_TensorViewWrapper>& args,
const std::vector<GPU_TensorViewWrapper>& out);
};
}
}
......
......@@ -170,9 +170,9 @@ namespace ngraph
{TI(ngraph::op::Dot), &GPU_Emitter::emit<ngraph::op::Dot>},
{TI(ngraph::op::Multiply), &GPU_Emitter::emit<ngraph::op::Multiply>},
{TI(ngraph::op::Parameter), &GPU_Emitter::nop},
{TI(ngraph::op::Abs), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Abs), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Concat), &GPU_Emitter::emit<ngraph::op::Concat>},
{TI(ngraph::op::Divide), &GPU_Emitter::emit<ngraph::op::Divide>},
{TI(ngraph::op::Divide), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Equal), &GPU_Emitter::emit<ngraph::op::Equal>},
{TI(ngraph::op::GetOutputElement),
&GPU_Emitter::emit<ngraph::op::GetOutputElement>},
......@@ -180,44 +180,44 @@ namespace ngraph
{TI(ngraph::op::GreaterEq), &GPU_Emitter::emit<ngraph::op::GreaterEq>},
{TI(ngraph::op::Less), &GPU_Emitter::emit<ngraph::op::Less>},
{TI(ngraph::op::LessEq), &GPU_Emitter::emit<ngraph::op::LessEq>},
{TI(ngraph::op::Log), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Log), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Maximum), &GPU_Emitter::emit<ngraph::op::Maximum>},
{TI(ngraph::op::Minimum), &GPU_Emitter::emit<ngraph::op::Minimum>},
{TI(ngraph::op::Negative), &GPU_Emitter::emit<ngraph::op::Negative>},
{TI(ngraph::op::NotEqual), &GPU_Emitter::emit<ngraph::op::NotEqual>},
{TI(ngraph::op::Power), &GPU_Emitter::emit<ngraph::op::Power>},
{TI(ngraph::op::Power), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Select), &GPU_Emitter::emit<ngraph::op::Select>},
{TI(ngraph::op::Subtract), &GPU_Emitter::emit<ngraph::op::Subtract>},
{TI(ngraph::op::Subtract), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Broadcast), &GPU_Emitter::emit<ngraph::op::Broadcast>},
{TI(ngraph::op::Convert), &GPU_Emitter::emit<ngraph::op::Convert>},
{TI(ngraph::op::Constant), &GPU_Emitter::emit<ngraph::op::Constant>},
{TI(ngraph::op::Reshape), &GPU_Emitter::emit<ngraph::op::Reshape>},
{TI(ngraph::op::FunctionCall), &GPU_Emitter::emit<ngraph::op::FunctionCall>},
{TI(ngraph::op::Reduce), &GPU_Emitter::emit<ngraph::op::Reduce>},
{TI(ngraph::op::Sign), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Sign), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Slice), &GPU_Emitter::emit<ngraph::op::Slice>},
{TI(ngraph::op::Sum), &GPU_Emitter::emit<ngraph::op::Sum>},
{TI(ngraph::op::Exp), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Sin), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Sinh), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Cos), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Cosh), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Tan), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Tanh), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Asin), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Acos), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Atan), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Exp), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Sin), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Sinh), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Cos), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Cosh), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Tan), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Tanh), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Asin), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Acos), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Atan), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::ReplaceSlice), &GPU_Emitter::emit<ngraph::op::ReplaceSlice>},
{TI(ngraph::op::OneHot), &GPU_Emitter::emit<ngraph::op::OneHot>},
{TI(ngraph::op::Floor), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Ceiling), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Floor), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Ceiling), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Sqrt), &GPU_Emitter::emit<ngraph::op::Sqrt>},
{TI(ngraph::op::Convolution), &GPU_Emitter::emit<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropFilters),
&GPU_Emitter::emit<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::ConvolutionBackpropData),
&GPU_Emitter::emit<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::Not), &GPU_Emitter::EmitUnaryElementwise},
{TI(ngraph::op::Not), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::MaxPool), &GPU_Emitter::emit<ngraph::op::MaxPool>},
{TI(ngraph::op::Reverse), &GPU_Emitter::emit<ngraph::op::Reverse>},
{TI(ngraph::op::Result), &GPU_Emitter::emit<ngraph::op::Result>},
......@@ -853,4 +853,4 @@ using namespace std;
}
}
}
}
\ No newline at end of file
}
......@@ -608,7 +608,6 @@ TEST(${BACKEND_NAME}, concat_5d)
TEST(${BACKEND_NAME}, divide)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend();
......@@ -1513,7 +1512,6 @@ TEST(${BACKEND_NAME}, select)
TEST(${BACKEND_NAME}, subtract)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
......@@ -3605,7 +3603,6 @@ TEST(${BACKEND_NAME}, sum_3d_to_vector_stable)
TEST(${BACKEND_NAME}, sign)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Sign>(A), op::ParameterVector{A});
......@@ -3626,7 +3623,6 @@ TEST(${BACKEND_NAME}, sign)
TEST(${BACKEND_NAME}, power)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
......@@ -3645,7 +3641,7 @@ TEST(${BACKEND_NAME}, power)
auto result = backend->make_primary_tensor_view(element::f32, shape);
cf->call({a, b}, {result});
EXPECT_EQ((vector<float>{1, 1, 729, 125}), read_vector<float>(result));
EXPECT_TRUE(test::all_close(vector<float>{1, 1, 729, 125}, read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, constant_equality_bool)
......