Commit 61df6725 authored by Rob Earhart, committed by Robert Kimball

[PlaidML] Specialize within namespaces (for Linux) (#1948)

parent 5698fa75
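Context for the change below: the old code declared explicit template specializations from outside their enclosing namespaces using qualified names (e.g. `void ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::operator()()` at global scope). MSVC accepts that form, but GCC and Clang reject it before C++17, which is presumably the Linux breakage this commit fixes; the cure is to reopen the namespaces and put the specializations inside them. A minimal sketch of the rule, using a hypothetical lib::Impl rather than any class from this repository:

#include <iostream>

namespace lib
{
    template <typename T>
    struct Impl
    {
        void run() { std::cout << "primary\n"; }
    };
}

// Pre-C++17, GCC/Clang reject this qualified, out-of-namespace form
// ("specialization of template in different namespace"), which is the
// shape the old code used:
//
//   template <>
//   struct lib::Impl<int> { void run() { std::cout << "special\n"; } };
//
// The portable form, adopted throughout this commit, reopens the namespace:
namespace lib
{
    template <>
    struct Impl<int>
    {
        void run() { std::cout << "special\n"; }
    };
}

int main()
{
    lib::Impl<int>{}.run(); // prints "special"
    return 0;
}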
@@ -28,10 +28,16 @@
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
 #include "ngraph/runtime/plaidml/plaidml_translate.hpp"
 
-// Abs performs a simple elementwise absolute value.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::operator()()
+namespace ngraph
 {
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // Abs performs a simple elementwise absolute value.
+            template <>
+            void Impl<op::Abs>::operator()()
+            {
     check_inputs(1);
     check_outputs(1);
     set_output(start_tile_function()
@@ -39,12 +45,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::operator()()
                    .add(builder::Output{"O"})
                    .add(builder::Elementwise{"O", "abs(I)"})
                    .finalize());
 }
 
 // Add performs a simple elementwise addition.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Add>::operator()()
+void Impl<op::Add>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -53,12 +59,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Add>::operator()()
                    .add(builder::Output{"C"})
                    .add(builder::Elementwise{"C", "A + B"})
                    .finalize());
 }
 
 // Ceiling performs a simple elementwise ceiling.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Ceiling>::operator()()
+void Impl<op::Ceiling>::operator()()
 {
     check_inputs(1);
     check_outputs(1);
     set_output(start_tile_function()
@@ -66,12 +72,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Ceiling>::operator()()
                    .add(builder::Output{"O"})
                    .add(builder::Elementwise{"O", "ceil(I)"})
                    .finalize());
 }
 
 // Divide performs a simple elementwise division.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Divide>::operator()()
+void Impl<op::Divide>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -80,12 +86,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Divide>::operator()()
                    .add(builder::Output{"C"})
                    .add(builder::Elementwise{"C", "A / B"})
                    .finalize());
 }
 
 // Floor performs a simple elementwise floor.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Floor>::operator()()
+void Impl<op::Floor>::operator()()
 {
     check_inputs(1);
     check_outputs(1);
     set_output(start_tile_function()
@@ -93,12 +99,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Floor>::operator()()
                    .add(builder::Output{"O"})
                    .add(builder::Elementwise{"O", "floor(I)"})
                    .finalize());
 }
 
 // Multiply performs a simple elementwise multiplication.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Multiply>::operator()()
+void Impl<op::Multiply>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -107,12 +113,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Multiply>::operator()()
                    .add(builder::Output{"C"})
                    .add(builder::Elementwise{"C", "A * B"})
                    .finalize());
 }
 
 // Negative performs a simple elementwise negation.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Negative>::operator()()
+void Impl<op::Negative>::operator()()
 {
     check_inputs(1);
     check_outputs(1);
     set_output(start_tile_function()
@@ -120,12 +126,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Negative>::operator()()
                    .add(builder::Output{"O"})
                    .add(builder::Elementwise{"O", "-I"})
                    .finalize());
 }
 
 // Relu implements a simple elementwise rectified linear unit.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Relu>::operator()()
+void Impl<op::Relu>::operator()()
 {
     check_inputs(1);
     check_outputs(1);
     set_output(start_tile_function()
@@ -133,12 +139,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Relu>::operator()()
                    .add(builder::Output{"O"})
                    .add(builder::Elementwise{"O", "relu(I)"})
                    .finalize());
 }
 
 // ReluBackprop computes the derivative of Relu.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::ReluBackprop>::operator()()
+void Impl<op::ReluBackprop>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -147,12 +153,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReluBackprop>::operator()()
                    .add(builder::Output{"DI"})
                    .add(builder::Elementwise{"DI", "I > 0 ? DO : 0"})
                    .finalize());
 }
 
 // Sigmoid computes a standard ML sigmoid: 1/(1+exp(-X))
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Sigmoid>::operator()()
+void Impl<op::Sigmoid>::operator()()
 {
     check_inputs(1);
     check_outputs(1);
     set_output(start_tile_function()
@@ -160,13 +166,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sigmoid>::operator()()
                    .add(builder::Output{"O"})
                    .add(builder::Elementwise{"O", "1/(1+exp(-I))"})
                    .finalize());
 }
 
 // SigmoidBackprop computes the derivative of a standard ML
 // sigmoid: dOutput * sigmoid(X) * (1-sigmoid(X))
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::SigmoidBackprop>::operator()()
+void Impl<op::SigmoidBackprop>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -176,26 +182,27 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::SigmoidBackprop>::operator()()
                    .add(builder::Elementwise{"O", "1/(1+exp(-I))"})
                    .add(builder::Elementwise{"DI", "DO * O * (1-O)"})
                    .finalize());
 }
 
 // Sign returns the sign of an element.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Sign>::operator()()
+void Impl<op::Sign>::operator()()
 {
     check_inputs(1);
     check_outputs(1);
     set_output(start_tile_function()
                    .add(builder::Input{op_input(0), "I"})
                    .add(builder::Output{"O"})
                    .add(builder::Elementwise{"S", "(I < 0) ? -1 : ((I > 0) ? 1 : 0)"})
-                   .add(builder::Elementwise{"O", tile_converter("S", op().get_element_type())})
+                   .add(builder::Elementwise{
+                       "O", tile_converter("S", op().get_element_type())})
                    .finalize());
 }
 
 // Subtract performs a simple elementwise subtraction.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Subtract>::operator()()
+void Impl<op::Subtract>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -204,22 +211,24 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Subtract>::operator()()
                    .add(builder::Output{"C"})
                    .add(builder::Elementwise{"C", "A - B"})
                    .finalize());
 }
 
 namespace
 {
-    ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::Registration register_abs;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Add>::Registration register_add;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Ceiling>::Registration register_ceiling;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Divide>::Registration register_divide;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Floor>::Registration register_floor;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Multiply>::Registration register_multiply;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Negative>::Registration register_negative;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Relu>::Registration register_relu;
-    ngraph::runtime::plaidml::Impl<ngraph::op::ReluBackprop>::Registration register_relu_backprop;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Sigmoid>::Registration register_sigmoid;
-    ngraph::runtime::plaidml::Impl<ngraph::op::SigmoidBackprop>::Registration
-        register_sigmoid_backprop;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Sign>::Registration register_sign;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Subtract>::Registration register_subtract;
+    Impl<op::Abs>::Registration register_abs;
+    Impl<op::Add>::Registration register_add;
+    Impl<op::Ceiling>::Registration register_ceiling;
+    Impl<op::Divide>::Registration register_divide;
+    Impl<op::Floor>::Registration register_floor;
+    Impl<op::Multiply>::Registration register_multiply;
+    Impl<op::Negative>::Registration register_negative;
+    Impl<op::Relu>::Registration register_relu;
+    Impl<op::ReluBackprop>::Registration register_relu_backprop;
+    Impl<op::Sigmoid>::Registration register_sigmoid;
+    Impl<op::SigmoidBackprop>::Registration register_sigmoid_backprop;
+    Impl<op::Sign>::Registration register_sign;
+    Impl<op::Subtract>::Registration register_subtract;
+}
+}
+}
 }
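The register_* objects in the trailing anonymous namespace are what make each specialization visible to the backend. The actual Registration type is defined in plaidml_impl.hpp, which this diff does not show; the following is only a sketch of the static self-registration idiom such objects typically implement, with a hypothetical registry and a stand-in op type:

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Hypothetical registry mapping an op name to a lowering handler. The real
// mechanism lives in plaidml_impl.hpp; this sketch only illustrates the
// pattern the anonymous-namespace register_* objects rely on.
std::map<std::string, std::function<void()>>& registry()
{
    static std::map<std::string, std::function<void()>> impls;
    return impls;
}

template <typename Op>
struct Impl
{
    void operator()(); // explicitly specialized per op, as in the commit

    struct Registration
    {
        // Constructing a Registration at namespace scope registers the
        // handler during static initialization, before main() runs.
        Registration() { registry()[Op::name()] = [] { Impl<Op>{}(); }; }
    };
};

struct Abs // stand-in for ngraph::op::Abs
{
    static const char* name() { return "Abs"; }
};

template <>
void Impl<Abs>::operator()()
{
    std::cout << "lowering Abs\n";
}

namespace
{
    Impl<Abs>::Registration register_abs; // mirrors the diff's pattern
}

int main()
{
    registry().at("Abs")(); // invokes the registered specialization
    return 0;
}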
@@ -18,11 +18,17 @@
 #include "ngraph/op/batch_norm.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
 
-// BatchNormInference implements batch normalization for inference, in
-// which the mean and variance to use are supplied.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()()
+namespace ngraph
 {
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // BatchNormInference implements batch normalization for inference, in
+            // which the mean and variance to use are supplied.
+            template <>
+            void Impl<op::BatchNormInference>::operator()()
+            {
     auto& input_shape = op().get_input_shape(2);
     check_inputs(5);
     check_outputs(1);
@@ -45,12 +51,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()()
     if (input_shape.size() <= 2)
     {
-        f.add(builder::Elementwise{"GammaP", "Gamma"}).add(builder::Elementwise{"BetaP", "Beta"});
+        f.add(builder::Elementwise{"GammaP", "Gamma"})
+            .add(builder::Elementwise{"BetaP", "Beta"});
     }
     else
     {
-        f.add(builder::Elementwise{"GammaP", std::string{"reshape(Gamma, C"} + ones + ")"})
-            .add(builder::Elementwise{"BetaP", std::string{"reshape(Beta, C"} + ones + ")"});
+        f.add(builder::Elementwise{"GammaP",
+                                   std::string{"reshape(Gamma, C"} + ones + ")"})
+            .add(builder::Elementwise{"BetaP",
+                                      std::string{"reshape(Beta, C"} + ones + ")"});
     }
 
     if (input_shape.size() <= 2)
@@ -59,7 +68,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()()
     }
     else
     {
-        f.add(builder::Elementwise{"MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
+        f.add(
+            builder::Elementwise{"MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
     }
 
     if (input_shape.size() <= 2)
@@ -68,24 +78,26 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()()
     }
     else
     {
-        f.add(builder::Elementwise{"VarianceP", std::string{"reshape(Variance, C"} + ones + ")"});
+        f.add(builder::Elementwise{"VarianceP",
+                                   std::string{"reshape(Variance, C"} + ones + ")"});
     }
 
     f.add(builder::Elementwise{"Normalized",
                                "(((Input-MeanP) / sqrt(VarianceP + " +
-                                   std::to_string(op().get_eps_value()) + ")) * GammaP) + BetaP"});
+                                   std::to_string(op().get_eps_value()) +
+                                   ")) * GammaP) + BetaP"});
 
     auto app = f.finalize();
     set_output(app);
 }
 
 // BatchNormTraining implements batch normalization for training, in
 // which the mean and variance are to be computed from the supplied
 // input.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
+void Impl<op::BatchNormTraining>::operator()()
 {
     auto& input_shape = op().get_input_shape(2);
     check_inputs(3);
     check_outputs(3);
@@ -108,12 +120,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
     if (input_shape.size() <= 2)
     {
-        f.add(builder::Elementwise{"GammaP", "Gamma"}).add(builder::Elementwise{"BetaP", "Beta"});
+        f.add(builder::Elementwise{"GammaP", "Gamma"})
+            .add(builder::Elementwise{"BetaP", "Beta"});
     }
     else
     {
-        f.add(builder::Elementwise{"GammaP", std::string{"reshape(Gamma, C"} + ones + ")"})
-            .add(builder::Elementwise{"BetaP", std::string{"reshape(Beta, C"} + ones + ")"});
+        f.add(builder::Elementwise{"GammaP",
+                                   std::string{"reshape(Gamma, C"} + ones + ")"})
+            .add(builder::Elementwise{"BetaP",
+                                      std::string{"reshape(Beta, C"} + ones + ")"});
     }
 
     if (input_shape.size() <= 2)
@@ -131,7 +146,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
     }
 
     f.add(builder::UnaryContraction{"+"}
-              .set(builder::ContractionOutput{"SumInput"}.add_indices({"c"}).add_dims({"C"}))
+              .set(builder::ContractionOutput{"SumInput"}.add_indices({"c"}).add_dims(
+                  {"C"}))
              .set(builder::ContractionInput{"Input"}
                       .add_indices({"b", "c"})
                       .add_indices("di", 3, input_shape.size() + 1)));
@@ -143,13 +159,16 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
     }
     else
     {
-        f.add(builder::Elementwise{"MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
+        f.add(
+            builder::Elementwise{"MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
     }
     f.add(builder::Elementwise{"DiffV", "(Input - MeanP)"})
         .add(builder::Elementwise{"SqDiffV", "DiffV*DiffV"})
         .add(builder::UnaryContraction{"+"}
-                 .set(builder::ContractionOutput{"SumSqDiffV"}.add_indices({"c"}).add_dims({"C"}))
+                 .set(builder::ContractionOutput{"SumSqDiffV"}
+                          .add_indices({"c"})
+                          .add_dims({"C"}))
                 .set(builder::ContractionInput{"SqDiffV"}
                          .add_indices({"b", "c"})
                          .add_indices("di", 3, input_shape.size() + 1)))
@@ -161,23 +180,25 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
     }
     else
     {
-        f.add(builder::Elementwise{"VarianceP", std::string{"reshape(Variance, C"} + ones + ")"});
+        f.add(builder::Elementwise{"VarianceP",
+                                   std::string{"reshape(Variance, C"} + ones + ")"});
     }
 
     f.add(builder::Elementwise{"Normalized",
                                "(((Input-MeanP) / sqrt(VarianceP + " +
-                                   std::to_string(op().get_eps_value()) + ")) * GammaP) + BetaP"});
+                                   std::to_string(op().get_eps_value()) +
+                                   ")) * GammaP) + BetaP"});
 
     auto app = f.finalize();
 
     set_output(0, app.get_output(0));
     set_output(1, app.get_output(1));
     set_output(2, app.get_output(2));
 }
 
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::operator()()
+void Impl<op::BatchNormTrainingBackprop>::operator()()
 {
     // WARNING: I'm unconvinced that we have sufficient test coverage for BatchNorm
     // backprop and in particular I'm concerned that Gamma/Beta and Mean/Var could be
     // swapped without the tests catching it.
@@ -232,10 +253,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::operator()()
              .set(builder::ContractionInput{"Input"}
                       .add_indices({"n", "c"})
                       .add_indices("x", 3, input_shape.size() + 1)));
-    f.add(builder::Elementwise{"BatchMean", "BatchMeanNumerator / " + reduction_dims.str()});
+    f.add(builder::Elementwise{"BatchMean",
+                               "BatchMeanNumerator / " + reduction_dims.str()});
     f.add(builder::Elementwise{"NegBatchMean", "-BatchMean"});
-    f.add(
-        builder::BinaryContraction{"=", "+"}
+    f.add(builder::BinaryContraction{"=", "+"}
             .set(builder::ContractionOutput{"Deviation"}
                      .add_indices({"n", "c"})
                      .add_indices("x", 3, input_shape.size() + 1)
@@ -244,7 +265,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::operator()()
            .set_lhs(builder::ContractionInput{"Input"}
                         .add_indices({"n", "c"})
                         .add_indices("x", 3, input_shape.size() + 1))
-            .set_rhs(builder::ContractionInput{"NegBatchMean"}.add_indices({"0", "c", "0", "0"})));
+            .set_rhs(builder::ContractionInput{"NegBatchMean"}.add_indices(
+                {"0", "c", "0", "0"})));
     f.add(builder::BinaryContraction{"+", "*"}
              .set(builder::ContractionOutput{"BatchVarNumerator"}
                       .add_indices({"0", "c", "0", "0"})
@@ -255,7 +277,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::operator()()
             .set_rhs(builder::ContractionInput{"Deviation"}
                         .add_indices({"n", "c"})
                         .add_indices("x", 3, input_shape.size() + 1)));
-    f.add(builder::Elementwise{"BatchVar", "BatchVarNumerator / " + reduction_dims.str()});
+    f.add(builder::Elementwise{"BatchVar",
+                               "BatchVarNumerator / " + reduction_dims.str()});
     f.add(builder::Elementwise{"BatchStdDev", "sqrt(BatchVar + " + epsilon + ")"});
     f.add(builder::Elementwise{"NormedInput", "(Input - BatchMean) / BatchStdDev"});
@@ -266,12 +289,14 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::operator()()
     f.add(builder::Elementwise{"DNormedInput", "DOutput * BroadcastGamma"});
     f.add(builder::UnaryContraction{"+"}
-              .set(builder::ContractionOutput{"SumDOutput"}.add_indices({"c"}).add_dims({"C"}))
+              .set(builder::ContractionOutput{"SumDOutput"}.add_indices({"c"}).add_dims(
+                  {"C"}))
             .set(builder::ContractionInput{"DOutput"}
                      .add_indices({"n", "c"})
                      .add_indices("x", 3, input_shape.size() + 1)));
     f.add(builder::BinaryContraction{"+", "*"}
-              .set(builder::ContractionOutput{"DGamma"}.add_indices({"c"}).add_dims({"C"}))
+              .set(builder::ContractionOutput{"DGamma"}.add_indices({"c"}).add_dims(
+                  {"C"}))
             .set_lhs(builder::ContractionInput{"DOutput"}
                         .add_indices({"n", "c"})
                         .add_indices("x", 3, input_shape.size() + 1))
@@ -295,14 +320,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::operator()()
     set_output(0, app.get_output(0));
     set_output(1, app.get_output(1));
     set_output(2, app.get_output(2));
 }
 
 namespace
 {
-    ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::Registration
-        register_batch_norm_inference;
-    ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::Registration
-        register_batch_norm_training;
-    ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::Registration
+    Impl<op::BatchNormInference>::Registration register_batch_norm_inference;
+    Impl<op::BatchNormTraining>::Registration register_batch_norm_training;
+    Impl<op::BatchNormTrainingBackprop>::Registration
         register_batch_norm_training_backprop;
+}
+}
+}
 }
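For reference, the Tile expression assembled above for both inference and training reduces, per element, to the usual batch-normalization formula. A scalar sketch with made-up values (not taken from any test in the repository):

#include <cmath>
#include <cstdio>

// Scalar form of the Tile expression built above:
//   Normalized = (((Input - Mean) / sqrt(Variance + eps)) * Gamma) + Beta
int main()
{
    const float input = 3.0f, mean = 2.0f, variance = 4.0f;
    const float gamma = 0.5f, beta = 1.0f, eps = 1e-5f;

    const float normalized =
        ((input - mean) / std::sqrt(variance + eps)) * gamma + beta;
    std::printf("%f\n", normalized); // ~1.25: (1 / 2) * 0.5 + 1
    return 0;
}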
@@ -24,10 +24,16 @@
 #include "ngraph/op/not_equal.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
 
-// Equal performs a simple elementwise equality.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Equal>::operator()()
+namespace ngraph
 {
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // Equal performs a simple elementwise equality.
+            template <>
+            void Impl<op::Equal>::operator()()
+            {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -37,12 +43,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Equal>::operator()()
                    .add(builder::Elementwise{"C", "A == B"})
                    .finalize(),
                TensorContents::LOGICAL);
 }
 
 // Greater performs a simple elementwise greater-than comparison.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Greater>::operator()()
+void Impl<op::Greater>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -52,12 +58,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Greater>::operator()()
                    .add(builder::Elementwise{"C", "A > B"})
                    .finalize(),
                TensorContents::LOGICAL);
 }
 
 // GreaterEq performs a simple elementwise greater-than-or-equal-to comparison.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::GreaterEq>::operator()()
+void Impl<op::GreaterEq>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -67,12 +73,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::GreaterEq>::operator()()
                    .add(builder::Elementwise{"C", "A >= B"})
                    .finalize(),
                TensorContents::LOGICAL);
 }
 
 // Less performs a simple elementwise less-than comparison.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Less>::operator()()
+void Impl<op::Less>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -82,12 +88,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Less>::operator()()
                    .add(builder::Elementwise{"C", "A < B"})
                    .finalize(),
                TensorContents::LOGICAL);
 }
 
 // LessEq performs a simple elementwise less-than-or-equal-to comparison.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::LessEq>::operator()()
+void Impl<op::LessEq>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -97,12 +103,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::LessEq>::operator()()
                    .add(builder::Elementwise{"C", "A <= B"})
                    .finalize(),
                TensorContents::LOGICAL);
 }
 
 // Maximum performs a simple elementwise maximum.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Maximum>::operator()()
+void Impl<op::Maximum>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -111,12 +117,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Maximum>::operator()()
                    .add(builder::Output{"C"})
                    .add(builder::Elementwise{"C", "max(A, B)"})
                    .finalize());
 }
 
 // Minimum performs a simple elementwise minimum.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Minimum>::operator()()
+void Impl<op::Minimum>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -125,12 +131,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Minimum>::operator()()
                    .add(builder::Output{"C"})
                    .add(builder::Elementwise{"C", "min(A, B)"})
                    .finalize());
 }
 
 // NotEqual performs a simple elementwise not-equality.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::NotEqual>::operator()()
+void Impl<op::NotEqual>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
     set_output(start_tile_function()
@@ -140,16 +146,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::NotEqual>::operator()()
                    .add(builder::Elementwise{"C", "A != B"})
                    .finalize(),
                TensorContents::LOGICAL);
 }
 
 namespace
 {
-    ngraph::runtime::plaidml::Impl<ngraph::op::Equal>::Registration register_equal;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Greater>::Registration register_greater;
-    ngraph::runtime::plaidml::Impl<ngraph::op::GreaterEq>::Registration register_greater_eq;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Less>::Registration register_less;
-    ngraph::runtime::plaidml::Impl<ngraph::op::LessEq>::Registration register_less_eq;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Maximum>::Registration register_maximum;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Minimum>::Registration register_minimum;
-    ngraph::runtime::plaidml::Impl<ngraph::op::NotEqual>::Registration register_not_equal;
+    Impl<op::Equal>::Registration register_equal;
+    Impl<op::Greater>::Registration register_greater;
+    Impl<op::GreaterEq>::Registration register_greater_eq;
+    Impl<op::Less>::Registration register_less;
+    Impl<op::LessEq>::Registration register_less_eq;
+    Impl<op::Maximum>::Registration register_maximum;
+    Impl<op::Minimum>::Registration register_minimum;
+    Impl<op::NotEqual>::Registration register_not_equal;
+}
+}
+}
 }
@@ -17,10 +17,16 @@
 #include "ngraph/op/concat.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
 
-// Concat concatenates tensors along an axis.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
+namespace ngraph
 {
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // Concat concatenates tensors along an axis.
+            template <>
+            void Impl<op::Concat>::operator()()
+            {
     check_outputs(1);
 
     auto f = start_tile_function();
@@ -52,10 +58,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
             continue;
         }
         std::string sidx{std::to_string(iidx)};
-        f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims("I" + sidx + "_D", 0, dim_count));
+        f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims(
+            "I" + sidx + "_D", 0, dim_count));
         f.add(builder::UnaryContraction{"="}
                   .set(builder::ContractionOutput{"E" + sidx}
-                           .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
+                           .add_dims([&](
+                               std::back_insert_iterator<std::list<std::string>> out) {
                                for (std::size_t idx = 0; idx < dim_count; ++idx)
                                {
                                    std::ostringstream s;
@@ -70,19 +78,22 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
                                    }
                                }
                            })
-                           .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
+                           .add_indices([&](
+                               std::back_insert_iterator<std::list<std::string>> out) {
                                for (std::size_t idx = 0; idx < dim_count; ++idx)
                                {
                                    std::ostringstream s;
                                    s << "d" << idx;
-                                   if (saw_non_zero_tensor && idx == op().get_concatenation_axis())
+                                   if (saw_non_zero_tensor &&
+                                       idx == op().get_concatenation_axis())
                                    {
                                        s << " + " << offset.str();
                                    }
                                    out = s.str();
                                }
                            }))
-                  .set(builder::ContractionInput{"I" + sidx}.add_indices("d", 0, dim_count)));
+                  .set(builder::ContractionInput{"I" + sidx}.add_indices(
+                      "d", 0, dim_count)));
         if (saw_non_zero_tensor)
         {
             oexpr << " + ";
@@ -95,9 +106,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
     f.add(builder::Elementwise{"O", oexpr.str()});
     set_output(f.finalize());
 }
 
 namespace
 {
-    ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::Registration register_concat;
+    Impl<op::Concat>::Registration register_concat;
+}
+}
+}
 }
@@ -18,21 +18,31 @@
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
 #include "ngraph/runtime/plaidml/plaidml_translate.hpp"
 
-// Convert views a tensor as a new type.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Convert>::operator()()
+namespace ngraph
 {
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // Convert views a tensor as a new type.
+            template <>
+            void Impl<op::Convert>::operator()()
+            {
     check_inputs(1);
     check_outputs(1);
-    set_output(start_tile_function()
+    set_output(
+        start_tile_function()
            .add(builder::Input{op_input(), "I"})
            .add(builder::Output{"O"})
            .add(builder::Elementwise{
                "O", tile_converter("I", to_plaidml(op().get_convert_element_type()))})
            .finalize());
 }
 
 namespace
 {
-    ngraph::runtime::plaidml::Impl<ngraph::op::Convert>::Registration register_convert;
+    Impl<op::Convert>::Registration register_convert;
+}
+}
+}
 }
@@ -50,32 +50,29 @@ namespace ngraph
                              std::size_t output_channel_axis_result,
                              bool rotate_filter);
             };
-}
-}
-}
 
 template <>
-struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Convolution>
+struct ParentImpl<op::Convolution>
 {
-    using Type = ngraph::runtime::plaidml::ConvolutionImpl<ngraph::op::Convolution>;
+    using Type = ConvolutionImpl<op::Convolution>;
 };
 
 template <>
-struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::ConvolutionBackpropFilters>
+struct ParentImpl<op::ConvolutionBackpropFilters>
 {
-    using Type = ngraph::runtime::plaidml::ConvolutionImpl<ngraph::op::ConvolutionBackpropFilters>;
+    using Type = ConvolutionImpl<op::ConvolutionBackpropFilters>;
 };
 
 template <>
-struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::ConvolutionBackpropData>
+struct ParentImpl<op::ConvolutionBackpropData>
 {
-    using Type = ngraph::runtime::plaidml::ConvolutionImpl<ngraph::op::ConvolutionBackpropData>;
+    using Type = ConvolutionImpl<op::ConvolutionBackpropData>;
 };
 
 // Convolution implements a standard ML convolution, with optional striding, padding, and dilation.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Convolution>::operator()()
+void Impl<op::Convolution>::operator()()
 {
     this->check_inputs(2);
     this->check_outputs(1);
@@ -122,13 +119,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Convolution>::operator()()
                    .set_lhs(cpf.I_in_body())
                    .set_rhs(cpf.F_in_body()))
            .finalize());
 }
 
 // ConvolutionBackpropFilters implements the derivative of a convolution with respect to its filter
 // input.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropFilters>::operator()()
+void Impl<op::ConvolutionBackpropFilters>::operator()()
 {
     this->check_inputs(2);
     this->check_outputs(1);
@@ -177,13 +174,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropFilters>::operator()()
                    .set_lhs(cpf.O_in_body())
                    .set_rhs(cpf.I_in_body()))
            .finalize());
 }
 
 // ConvolutionBackpropData implements the derivative of a convolution with respect to its data
 // input.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropData>::operator()()
+void Impl<op::ConvolutionBackpropData>::operator()()
 {
     this->check_inputs(2);
     this->check_outputs(1);
@@ -232,11 +229,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropData>::operator()()
                    .set_lhs(cpf.O_in_body())
                    .set_rhs(cpf.F_in_body()))
            .finalize());
 }
 
 template <typename O>
-inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
-    vertexai::plaidml::variable image,
+inline void ConvolutionImpl<O>::LogConvolution(vertexai::plaidml::variable image,
     vertexai::plaidml::variable filter,
     std::size_t image_dims,
     const Strides& window_movement_strides,
@@ -251,7 +247,7 @@ inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
     std::size_t batch_axis_result,
     std::size_t output_channel_axis_result,
     bool rotate_filter)
 {
     this->check_inputs(2);
     this->check_outputs(1);
@@ -271,13 +267,15 @@ inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
     NGRAPH_DEBUG << "batch_axis_result: " << batch_axis_result;
     NGRAPH_DEBUG << "output_channel_axis_result: " << output_channel_axis_result;
     NGRAPH_DEBUG << "rotate_filter: " << rotate_filter;
 }
 
 namespace
 {
-    ngraph::runtime::plaidml::Impl<ngraph::op::Convolution>::Registration register_convolution;
-    ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropFilters>::Registration
+    Impl<op::Convolution>::Registration register_convolution;
+    Impl<op::ConvolutionBackpropFilters>::Registration
         register_convolution_backprop_filters;
-    ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropData>::Registration
-        register_convolution_backprop_data;
+    Impl<op::ConvolutionBackpropData>::Registration register_convolution_backprop_data;
+}
+}
+}
 }
@@ -20,11 +20,17 @@
 #include "ngraph/op/dot.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
 
-// Dot is a generalized dot product operation -- scalar-tensor,
-// matrix-vector, and matrix multiplication.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
+namespace ngraph
 {
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // Dot is a generalized dot product operation -- scalar-tensor,
+            // matrix-vector, and matrix multiplication.
+            template <>
+            void Impl<op::Dot>::operator()()
+            {
     check_inputs(2);
     check_outputs(1);
@@ -40,7 +46,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
     NGRAPH_DEBUG << "l_dim_mac=" << l_dim_mac;
     NGRAPH_DEBUG << "r_dim_mic=" << r_dim_mic;
 
-    set_output(start_tile_function()
+    set_output(
+        start_tile_function()
            .add(builder::Input{op_input(0), "L"}
                     .add_dims("DL", 1, l_dim_mac + 1)
                     .add_dims("DC", 1, reduce_limit + 1))
@@ -61,9 +68,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
                         .add_indices("dc", 1, reduce_limit + 1)
                         .add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)))
            .finalize());
 }
 
 namespace
 {
-    ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::Registration register_dot;
+    Impl<op::Dot>::Registration register_dot;
+}
+}
+}
 }
@@ -19,10 +19,16 @@
 #include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
 
-// FunctionCall invokes a sub-function.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
+namespace ngraph
 {
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // FunctionCall invokes a sub-function.
+            template <>
+            void Impl<op::FunctionCall>::operator()()
+            {
     Build b;
     build()->compiler->build(op().get_functions()[0], &b);
     vertexai::plaidml::function f{b.composer};
@@ -30,7 +36,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
     for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
     {
         auto* oitv = op().get_inputs()[idx].get_output().get_tensor_ptr().get();
-        auto* iitv = b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
+        auto* iitv =
+            b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
         inputs.emplace_back(b.input_names.at(iitv), build()->bindings.at(oitv).var);
     }
     vertexai::plaidml::application app{f.apply(inputs)};
@@ -39,9 +46,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
         auto* iotv = b.func->get_results()[idx]->get_output_tensor_ptr().get();
         set_output(idx, app.get_output(b.output_names[iotv]));
     }
 }
 
 namespace
 {
-    ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::Registration register_function_call;
+    Impl<op::FunctionCall>::Registration register_function_call;
+}
+}
+}
 }
@@ -28,10 +28,16 @@
 namespace vp = vertexai::plaidml;
 
-// Broadcast broadcasts a tensor to a wider shape.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
+namespace ngraph
 {
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // Broadcast broadcasts a tensor to a wider shape.
+            template <>
+            void Impl<op::Broadcast>::operator()()
+            {
     check_inputs(1);
     check_outputs(1);
@@ -57,15 +63,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
        start_tile_function()
            .add(builder::Input{op_input(0), "I"}.add_rdims("D", in_dim_limit, 0))
            .add(builder::Output{"O"})
-            .add(builder::UnaryContraction{"="}
-                     .set(builder::ContractionOutput{"O"}
+            .add(
+                builder::UnaryContraction{"="}
+                    .set(
+                        builder::ContractionOutput{"O"}
                            .add_rindices("o", out_dim_limit, 0)
-                            .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
+                            .add_dims([&](
+                                std::back_insert_iterator<std::list<std::string>> out) {
                                for (std::size_t idx = 0; idx < out_dim_limit; ++idx)
                                {
                                    if (op().get_broadcast_axes().count(idx))
                                    {
-                                        out = std::to_string(op().get_broadcast_shape()[idx]);
+                                        out = std::to_string(
+                                            op().get_broadcast_shape()[idx]);
                                    }
                                    else
                                    {
@@ -81,12 +91,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
                                }
                            })))
            .finalize());
 }
 
 // Constant fills in a tensor constant.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
+void Impl<op::Constant>::operator()()
 {
     check_inputs(0);
     check_outputs(1);
@@ -105,48 +115,51 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
     switch (to_plaidml(op().get_element_type()))
     {
     case PLAIDML_DATA_BOOLEAN:
-        set_output(static_cast<std::int64_t>(*static_cast<const char*>(op().get_data_ptr())));
+        set_output(static_cast<std::int64_t>(
+            *static_cast<const char*>(op().get_data_ptr())));
         return;
     case PLAIDML_DATA_INT8:
-        set_output(
-            static_cast<std::int64_t>(*static_cast<const std::int8_t*>(op().get_data_ptr())));
+        set_output(static_cast<std::int64_t>(
+            *static_cast<const std::int8_t*>(op().get_data_ptr())));
         return;
     case PLAIDML_DATA_INT16:
-        set_output(
-            static_cast<std::int64_t>(*static_cast<const std::int16_t*>(op().get_data_ptr())));
+        set_output(static_cast<std::int64_t>(
+            *static_cast<const std::int16_t*>(op().get_data_ptr())));
         return;
     case PLAIDML_DATA_INT32:
-        set_output(
-            static_cast<std::int64_t>(*static_cast<const std::int32_t*>(op().get_data_ptr())));
+        set_output(static_cast<std::int64_t>(
+            *static_cast<const std::int32_t*>(op().get_data_ptr())));
         return;
     case PLAIDML_DATA_INT64:
         set_output(*static_cast<const std::int64_t*>(op().get_data_ptr()));
         return;
     case PLAIDML_DATA_UINT8:
-        set_output(
-            static_cast<std::int64_t>(*static_cast<const std::uint8_t*>(op().get_data_ptr())));
+        set_output(static_cast<std::int64_t>(
+            *static_cast<const std::uint8_t*>(op().get_data_ptr())));
         return;
     case PLAIDML_DATA_UINT16:
-        set_output(
-            static_cast<std::int64_t>(*static_cast<const std::uint16_t*>(op().get_data_ptr())));
+        set_output(static_cast<std::int64_t>(
+            *static_cast<const std::uint16_t*>(op().get_data_ptr())));
        return;
    case PLAIDML_DATA_UINT32:
-        set_output(
-            static_cast<std::int64_t>(*static_cast<const std::uint32_t*>(op().get_data_ptr())));
+        set_output(static_cast<std::int64_t>(
+            *static_cast<const std::uint32_t*>(op().get_data_ptr())));
        return;
    case PLAIDML_DATA_UINT64:
-        set_output(
-            static_cast<std::int64_t>(*static_cast<const std::uint64_t*>(op().get_data_ptr())));
+        set_output(static_cast<std::int64_t>(
+            *static_cast<const std::uint64_t*>(op().get_data_ptr())));
        return;
    case PLAIDML_DATA_FLOAT16:
        set_output(static_cast<double>(
            static_cast<float>(*static_cast<const half*>(op().get_data_ptr()))));
        return;
    case PLAIDML_DATA_FLOAT32:
-        set_output(static_cast<double>(*static_cast<const float*>(op().get_data_ptr())));
+        set_output(
+            static_cast<double>(*static_cast<const float*>(op().get_data_ptr())));
        return;
    case PLAIDML_DATA_FLOAT64:
-        set_output(static_cast<double>(*static_cast<const double*>(op().get_data_ptr())));
+        set_output(
+            static_cast<double>(*static_cast<const double*>(op().get_data_ptr())));
        return;
    default: break;
    }
@@ -163,22 +176,22 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
    }
    set_output(tensor);
 }
 
 // GetOutputElement pipes one of its N inputs to its output.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::GetOutputElement>::operator()()
+void Impl<op::GetOutputElement>::operator()()
 {
     check_inputs_ge(op().get_n() + 1);
     check_outputs(1);
 
     set_output(op_input(op().get_n()));
 }
 
 // Pad adds interior and exterior padding to a tensor.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
+void Impl<op::Pad>::operator()()
 {
     check_inputs(2);
     check_outputs(1);
@@ -214,7 +227,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
     auto out_dsize = [&](std::size_t idx) {
         std::ostringstream s;
-        std::size_t total_pad = op().get_padding_below().at(idx) + op().get_padding_above().at(idx);
+        std::size_t total_pad =
+            op().get_padding_below().at(idx) + op().get_padding_above().at(idx);
         std::size_t in_dsize = op().get_input_shape(0).at(idx);
         if (in_dsize)
         {
@@ -267,40 +281,48 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
     if (!any_zero_dims)
     {
         f.add(builder::Input{op_input(0), "I"}.add_dims("DI", 1, dim_limit + 1))
-            .add(builder::UnaryContraction{"="}
-                     .set(builder::ContractionOutput{"P"}
-                              .add_indices(
-                                  [&](std::back_insert_iterator<std::list<std::string>> out) {
+            .add(
+                builder::UnaryContraction{"="}
+                    .set(
+                        builder::ContractionOutput{"P"}
+                            .add_indices([&](
+                                std::back_insert_iterator<std::list<std::string>> out) {
                                for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                {
                                    out = out_didx(idx);
                                }
                            })
-                              .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
+                            .add_dims([&](
+                                std::back_insert_iterator<std::list<std::string>> out) {
                                for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                {
                                    out = out_dsize(idx);
                                }
                            }))
-                     .set(builder::ContractionInput{"I"}.add_indices("d", 1, dim_limit + 1)))
+                    .set(builder::ContractionInput{"I"}.add_indices(
+                        "d", 1, dim_limit + 1)))
            .add(builder::Elementwise{"T", "1"})
-            .add(builder::UnaryContraction{"="}
-                     .set(builder::ContractionOutput{"F"}
-                              .add_indices(
-                                  [&](std::back_insert_iterator<std::list<std::string>> out) {
+            .add(
+                builder::UnaryContraction{"="}
+                    .set(
+                        builder::ContractionOutput{"F"}
+                            .add_indices([&](
+                                std::back_insert_iterator<std::list<std::string>> out) {
                                for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                {
                                    out = out_didx(idx);
                                }
                            })
-                              .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
+                            .add_dims([&](
+                                std::back_insert_iterator<std::list<std::string>> out) {
                                for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                {
                                    out = out_dsize(idx);
                                }
                            }))
                    .set(builder::ContractionInput{"T"})
-                    .add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
+                    .add_constraints(
+                        [&](std::back_insert_iterator<std::list<std::string>> out) {
                        for (std::size_t idx = 0; idx < dim_limit; ++idx)
                        {
                            out = flag_constraints(idx);
@@ -313,7 +335,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
     f.add(builder::UnaryContraction{"="}
              .set(builder::ContractionOutput{"O"}
                       .add_indices("d", 0, dim_limit)
-                       .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
+                       .add_dims([&](
+                           std::back_insert_iterator<std::list<std::string>> out) {
                           for (std::size_t idx = 0; idx < dim_limit; ++idx)
                           {
                              out = out_dsize(idx);
@@ -323,12 +346,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
     }
 
     set_output(f.finalize());
 }
 
 // Reshape reshapes an input tensor.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
+void Impl<op::Reshape>::operator()()
 {
     check_inputs(1);
     check_outputs(1);
@@ -351,13 +374,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
                    .add(builder::UnaryContraction{"="}
                             .set(builder::ContractionOutput{"O"}
                                      .add_indices("d", 0, out_shape.size())
-                                     .add_dims(
-                                         [&](std::back_insert_iterator<std::list<std::string>> out) {
-                                             std::transform(
-                                                 out_shape.begin(),
-                                                 out_shape.end(),
-                                                 out,
-                                                 [](std::size_t sz) { return std::to_string(sz); });
+                                     .add_dims([&](
+                                         std::back_insert_iterator<std::list<std::string>>
+                                             out) {
+                                         std::transform(out_shape.begin(),
+                                                        out_shape.end(),
+                                                        out,
+                                                        [](std::size_t sz) {
+                                                            return std::to_string(sz);
+                                                        });
                                      }))
                             .set(builder::ContractionInput{"I"}))
                    .finalize());
@@ -374,28 +399,31 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
         // it's also rearranging the elements of the input tensor. This is pretty easy to
        // handle with a contraction.
 
-        src =
-            start_tile_function()
+        src = start_tile_function()
                  .add(builder::Input{src, "I"}.add_dims("D", 1, dim_limit + 1))
                  .add(builder::Output{"O"})
-                  .add(
-                      builder::UnaryContraction{"="}
+                  .add(builder::UnaryContraction{"="}
                          .set(builder::ContractionOutput{"O"}
-                                   .add_indices([&](
-                                       std::back_insert_iterator<std::list<std::string>> out) {
-                                       for (std::size_t idx = 0; idx < dim_limit; ++idx)
+                                   .add_indices([&](std::back_insert_iterator<
+                                       std::list<std::string>> out) {
+                                       for (std::size_t idx = 0; idx < dim_limit;
+                                            ++idx)
                                       {
-                                           out = "d" + std::to_string(input_order[idx] + 1);
+                                           out = "d" + std::to_string(
+                                               input_order[idx] + 1);
                                       }
                                   })
-                                   .add_dims([&](
-                                       std::back_insert_iterator<std::list<std::string>> out) {
-                                       for (std::size_t idx = 0; idx < dim_limit; ++idx)
+                                   .add_dims([&](std::back_insert_iterator<
+                                       std::list<std::string>> out) {
+                                       for (std::size_t idx = 0; idx < dim_limit;
+                                            ++idx)
                                       {
-                                           out = "D" + std::to_string(input_order[idx] + 1);
+                                           out = "D" + std::to_string(
+                                               input_order[idx] + 1);
                                       }
                                   }))
-                          .set(builder::ContractionInput{"I"}.add_indices("d", 1, dim_limit + 1)))
+                          .set(builder::ContractionInput{"I"}.add_indices(
+                              "d", 1, dim_limit + 1)))
                  .finalize();
        break;
     }
@@ -414,12 +442,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
            .add(builder::Output{"O"})
            .add(builder::Elementwise("O", reshape_expr.str()))
            .finalize());
 }
 
 // Select conditionally selects elements from input tensors.
 template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Select>::operator()()
+void Impl<op::Select>::operator()()
{ {
check_inputs(3); check_inputs(3);
check_outputs(1); check_outputs(1);
...@@ -430,26 +458,28 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Select>::operator()() ...@@ -430,26 +458,28 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Select>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "C ? T : F"}) .add(builder::Elementwise{"O", "C ? T : F"})
.finalize()); .finalize());
} }
// Used by nGraph for bprop graph generation, no-op as a kernel // Used by nGraph for bprop graph generation, no-op as a kernel
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::StopGradient>::operator()() void Impl<op::StopGradient>::operator()()
{ {
set_output(start_tile_function() set_output(start_tile_function()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "0"}) .add(builder::Elementwise{"O", "0"})
.finalize()); .finalize());
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::Registration register_broadcast; Impl<op::Broadcast>::Registration register_broadcast;
ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::Registration register_constant; Impl<op::Constant>::Registration register_constant;
ngraph::runtime::plaidml::Impl<ngraph::op::GetOutputElement>::Registration Impl<op::GetOutputElement>::Registration register_get_output_element;
register_get_output_element; Impl<op::Pad>::Registration register_pad;
ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::Registration register_pad; Impl<op::Reshape>::Registration register_reshape;
ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::Registration register_reshape; Impl<op::Select>::Registration register_select;
ngraph::runtime::plaidml::Impl<ngraph::op::Select>::Registration register_select; Impl<op::StopGradient>::Registration register_stop_gradient;
ngraph::runtime::plaidml::Impl<ngraph::op::StopGradient>::Registration register_stop_gradient; }
}
}
} }
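For readers skimming the diff: after this change each kernel specialization sits inside the nested ngraph::runtime::plaidml namespaces instead of fully qualifying its name at the specialization site. A minimal sketch of the resulting pattern, using a hypothetical elementwise negation kernel (illustrative only, not part of this commit; it assumes Tile accepts unary minus):

    namespace ngraph
    {
        namespace runtime
        {
            namespace plaidml
            {
                // Negative performs a simple elementwise negation (illustrative).
                template <>
                void Impl<op::Negative>::operator()()
                {
                    check_inputs(1);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(), "I"})
                                   .add(builder::Output{"O"})
                                   .add(builder::Elementwise{"O", "-I"})
                                   .finalize());
                }

                namespace
                {
                    Impl<op::Negative>::Registration register_negative;
                }
            }
        }
    }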
...@@ -36,13 +36,10 @@ namespace ngraph ...@@ -36,13 +36,10 @@ namespace ngraph
void build_index_reduction(const char* agg_op); void build_index_reduction(const char* agg_op);
}; };
}
}
}
template <typename O> template <typename O>
void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(const char* agg_op) void IndexReductionImpl<O>::build_index_reduction(const char* agg_op)
{ {
this->check_inputs(1); this->check_inputs(1);
this->check_outputs(1); this->check_outputs(1);
...@@ -56,16 +53,20 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons ...@@ -56,16 +53,20 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add( // Compute the maxes along the specified axis in the input .add( // Compute the maxes along the specified axis in the input
builder::UnaryContraction{agg_op} builder::UnaryContraction{agg_op}
.set(builder::ContractionOutput{"SelVal"} .set(
builder::ContractionOutput{"SelVal"}
.add_indices([&]( .add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) { std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < dim_limit; ++idx) for (auto idx = 0; idx < dim_limit; ++idx)
{ {
out = (idx == this->op().get_reduction_axis() ? "rd" : "d") + out =
(idx == this->op().get_reduction_axis() ? "rd"
: "d") +
std::to_string(idx); std::to_string(idx);
} }
}) })
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) { .add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < dim_limit; ++idx) for (auto idx = 0; idx < dim_limit; ++idx)
{ {
if (idx == this->op().get_reduction_axis()) if (idx == this->op().get_reduction_axis())
...@@ -82,13 +83,14 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons ...@@ -82,13 +83,14 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
.add( // Compare the input against the (broadcasted) max values, and select the indices .add( // Compare the input against the (broadcasted) max values, and select the indices
// where the max val occurs // where the max val occurs
builder::Elementwise{"SelValIdxs", builder::Elementwise{"SelValIdxs",
"I == SelVal ? index(I, " + reduction_axis_str + ") : D" + "I == SelVal ? index(I, " + reduction_axis_str +
reduction_axis_str}) ") : D" + reduction_axis_str})
.add( // Select the maximum index .add( // Select the maximum index
builder::UnaryContraction{"<"} builder::UnaryContraction{"<"}
.set(builder::ContractionOutput{"SelIdx"} .set(
.add_indices( builder::ContractionOutput{"SelIdx"}
[&](std::back_insert_iterator<std::list<std::string>> out) { .add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < dim_limit; ++idx) for (auto idx = 0; idx < dim_limit; ++idx)
{ {
if (idx != this->op().get_reduction_axis()) if (idx != this->op().get_reduction_axis())
...@@ -97,7 +99,8 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons ...@@ -97,7 +99,8 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
} }
} }
}) })
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) { .add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < dim_limit; ++idx) for (auto idx = 0; idx < dim_limit; ++idx)
{ {
if (idx != this->op().get_reduction_axis()) if (idx != this->op().get_reduction_axis())
...@@ -106,41 +109,45 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons ...@@ -106,41 +109,45 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
} }
} }
})) }))
.set(builder::ContractionInput{"SelValIdxs"}.add_indices("d", 0, dim_limit))) .set(builder::ContractionInput{"SelValIdxs"}.add_indices(
"d", 0, dim_limit)))
.add( // Convert to the requested output element type (if any) .add( // Convert to the requested output element type (if any)
builder::Elementwise{"O", builder::Elementwise{
tile_converter("SelIdx", this->op().get_index_element_type())}) "O", tile_converter("SelIdx", this->op().get_index_element_type())})
.finalize()); .finalize());
} }
template <> template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::ArgMax> struct ParentImpl<op::ArgMax>
{ {
using Type = IndexReductionImpl<ngraph::op::ArgMax>; using Type = IndexReductionImpl<op::ArgMax>;
}; };
template <> template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::ArgMin> struct ParentImpl<op::ArgMin>
{ {
using Type = IndexReductionImpl<ngraph::op::ArgMin>; using Type = IndexReductionImpl<op::ArgMin>;
}; };
// ArgMax computes the maximum index along a tensor axis. // ArgMax computes the maximum index along a tensor axis.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::ArgMax>::operator()() void Impl<op::ArgMax>::operator()()
{ {
build_index_reduction(">"); build_index_reduction(">");
} }
// ArgMin computes the minimum index along a tensor axis. // ArgMin computes the minimum index along a tensor axis.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::ArgMin>::operator()() void Impl<op::ArgMin>::operator()()
{ {
build_index_reduction("<"); build_index_reduction("<");
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::ArgMax>::Registration register_argmax; Impl<op::ArgMax>::Registration register_argmax;
ngraph::runtime::plaidml::Impl<ngraph::op::ArgMin>::Registration register_argmin; Impl<op::ArgMin>::Registration register_argmin;
}
}
}
} }
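The ParentImpl specializations above are how an op routes through the shared IndexReductionImpl. Wiring a new index-reduction op would follow the same two steps; a sketch with a made-up op::ArgBest:

    // Step 1: route the op's Impl through IndexReductionImpl.
    template <>
    struct ParentImpl<op::ArgBest>
    {
        using Type = IndexReductionImpl<op::ArgBest>;
    };

    // Step 2: pick the aggregation operator (">" keeps maxima, "<" keeps minima).
    template <>
    void Impl<op::ArgBest>::operator()()
    {
        build_index_reduction(">");
    }

    namespace
    {
        Impl<op::ArgBest>::Registration register_argbest;
    }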
...@@ -20,10 +20,16 @@ ...@@ -20,10 +20,16 @@
namespace vp = vertexai::plaidml; namespace vp = vertexai::plaidml;
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder. namespace ngraph
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::operator()()
{ {
namespace runtime
{
namespace plaidml
{
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
template <>
void Impl<op::Parameter>::operator()()
{
check_inputs(0); check_inputs(0);
check_outputs(1); check_outputs(1);
vp::placeholder ph{build()->io_dim_override ? build()->io_dim_override_count vp::placeholder ph{build()->io_dim_override ? build()->io_dim_override_count
...@@ -33,22 +39,25 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::operator()() ...@@ -33,22 +39,25 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::operator()()
build()->bindings.emplace(tv, TensorInfo{ph, TensorContents::DATA}); build()->bindings.emplace(tv, TensorInfo{ph, TensorContents::DATA});
build()->composer.input(name, ph); build()->composer.input(name, ph);
build()->input_names.emplace(tv, std::move(name)); build()->input_names.emplace(tv, std::move(name));
} }
// Result binds a PlaidML variable to a composed function output. // Result binds a PlaidML variable to a composed function output.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Result>::operator()() void Impl<op::Result>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
std::string name = std::string{"O"} + std::to_string(build()->output_names.size()); std::string name = std::string{"O"} + std::to_string(build()->output_names.size());
descriptor::Tensor* tv = op().get_output_tensor_ptr().get(); descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
build()->composer.output(name, op_input()); build()->composer.output(name, op_input());
build()->output_names.emplace(tv, std::move(name)); build()->output_names.emplace(tv, std::move(name));
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::Registration register_parameter; Impl<op::Parameter>::Registration register_parameter;
ngraph::runtime::plaidml::Impl<ngraph::op::Result>::Registration register_result; Impl<op::Result>::Registration register_result;
}
}
}
} }
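The Result binding above names composed-function outputs positionally ("O0", "O1", ...). A self-contained sketch of just that naming scheme (the map key is simplified to an int; the real code keys on descriptor::Tensor*):

    #include <iostream>
    #include <map>
    #include <string>

    int main()
    {
        // Mirrors: std::string name = std::string{"O"} + std::to_string(output_names.size());
        std::map<int, std::string> output_names;
        for (int tensor_id : {10, 11, 12})
        {
            std::string name = std::string{"O"} + std::to_string(output_names.size());
            output_names.emplace(tensor_id, std::move(name));
        }
        for (const auto& kv : output_names)
            std::cout << kv.first << " -> " << kv.second << "\n"; // 10 -> O0, 11 -> O1, 12 -> O2
    }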
...@@ -17,21 +17,29 @@ ...@@ -17,21 +17,29 @@
#include "ngraph/op/lrn.hpp" #include "ngraph/op/lrn.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// LRN implements Local Response Normalization namespace ngraph
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::operator()()
{ {
namespace runtime
{
namespace plaidml
{
// LRN implements Local Response Normalization
template <>
void Impl<op::LRN>::operator()()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
auto dim_limit = op().get_inputs()[0].get_shape().size(); auto dim_limit = op().get_inputs()[0].get_shape().size();
auto rank = dim_limit - 2; auto rank = dim_limit - 2;
auto distance = op().get_nsize() / 2; auto distance = op().get_nsize() / 2;
std::ostringstream div_expr; std::ostringstream div_expr;
div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha() << ".0 / " div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha()
<< op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)"; << ".0 / " << op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)";
set_output( set_output(
start_tile_function() start_tile_function()
.add(builder::Input{op_input(), "I"}.add_dims({"N", "C"}).add_dims("D", 0, rank)) .add(builder::Input{op_input(), "I"}
.add_dims({"N", "C"})
.add_dims("D", 0, rank))
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"ISQ", "I * I"}) .add(builder::Elementwise{"ISQ", "I * I"})
.add(builder::UnaryContraction{"+"} .add(builder::UnaryContraction{"+"}
...@@ -43,14 +51,18 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::operator()() ...@@ -43,14 +51,18 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::operator()()
.set(builder::ContractionInput{"ISQ"} .set(builder::ContractionInput{"ISQ"}
.add_indices({"n", "c + z - " + std::to_string(distance)}) .add_indices({"n", "c + z - " + std::to_string(distance)})
.add_indices("d", 0, rank)) .add_indices("d", 0, rank))
.add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) { .add_constraints(
[&](std::back_insert_iterator<std::list<std::string>> out) {
out = "z < " + std::to_string(op().get_nsize()); out = "z < " + std::to_string(op().get_nsize());
})) }))
.add(builder::Elementwise{"O", div_expr.str()}) .add(builder::Elementwise{"O", div_expr.str()})
.finalize()); .finalize());
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::Registration register_local_response_norm; Impl<op::LRN>::Registration register_local_response_norm;
}
}
}
} }
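To see what div_expr actually produces, here is a standalone reproduction of just the string construction, with illustrative integer-valued hyperparameters standing in for the op's accessors (assumed values, not from this diff):

    #include <cstddef>
    #include <iostream>
    #include <sstream>

    int main()
    {
        // Stand-ins for op().get_bias(), get_alpha(), get_nsize(), get_beta().
        double bias = 2, alpha = 1, beta = 1;
        std::size_t nsize = 5;

        std::ostringstream div_expr;
        div_expr << "I / pow(" << bias << ".0 + ((" << alpha << ".0 / " << nsize
                 << ".0) * S), " << beta << ".0)";

        // S is the windowed sum of squares computed by the contraction above.
        std::cout << div_expr.str() << "\n";
        // Prints: I / pow(2.0 + ((1.0 / 5.0) * S), 1.0)
    }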
...@@ -19,10 +19,16 @@ ...@@ -19,10 +19,16 @@
#include "ngraph/op/or.hpp" #include "ngraph/op/or.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// And performs a simple elementwise logical and. namespace ngraph
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::And>::operator()()
{ {
namespace runtime
{
namespace plaidml
{
// And performs a simple elementwise logical and.
template <>
void Impl<op::And>::operator()()
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -32,12 +38,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::And>::operator()() ...@@ -32,12 +38,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::And>::operator()()
.add(builder::Elementwise{"C", "A ? B : A"}) .add(builder::Elementwise{"C", "A ? B : A"})
.finalize(), .finalize(),
TensorContents::LOGICAL); TensorContents::LOGICAL);
} }
// Not performs a simple elementwise logical not. // Not performs a simple elementwise logical not.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Not>::operator()() void Impl<op::Not>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -46,12 +52,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Not>::operator()() ...@@ -46,12 +52,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Not>::operator()()
.add(builder::Elementwise{"O", "cmp_eq(I, 0)"}) .add(builder::Elementwise{"O", "cmp_eq(I, 0)"})
.finalize(), .finalize(),
TensorContents::LOGICAL); TensorContents::LOGICAL);
} }
// Or performs a simple elementwise logical or. // Or performs a simple elementwise logical or.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Or>::operator()() void Impl<op::Or>::operator()()
{ {
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -61,11 +67,14 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Or>::operator()() ...@@ -61,11 +67,14 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Or>::operator()()
.add(builder::Elementwise{"C", "A ? A : B"}) .add(builder::Elementwise{"C", "A ? A : B"})
.finalize(), .finalize(),
TensorContents::LOGICAL); TensorContents::LOGICAL);
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::And>::Registration register_and; Impl<op::And>::Registration register_and;
ngraph::runtime::plaidml::Impl<ngraph::op::Not>::Registration register_not; Impl<op::Not>::Registration register_not;
ngraph::runtime::plaidml::Impl<ngraph::op::Or>::Registration register_or; Impl<op::Or>::Registration register_or;
}
}
}
} }
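The And/Or kernels avoid a dedicated logical primitive by leaning on Tile's select operator: "A ? B : A" is conjunction and "A ? A : B" is disjunction for 0/1 operands. A quick truth-table check of the same expressions in plain C++:

    #include <iostream>

    int main()
    {
        for (int a = 0; a <= 1; ++a)
        {
            for (int b = 0; b <= 1; ++b)
            {
                std::cout << "A=" << a << " B=" << b
                          << "  and=" << (a ? b : a) // matches "A ? B : A"
                          << "  or=" << (a ? a : b)  // matches "A ? A : B"
                          << "\n";
            }
        }
    }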
...@@ -20,10 +20,16 @@ ...@@ -20,10 +20,16 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp" #include "ngraph/runtime/plaidml/plaidml_translate.hpp"
// OneHot performs one-hot encoding along the requested axis. namespace ngraph
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()()
{ {
namespace runtime
{
namespace plaidml
{
// OneHot performs one-hot encoding along the requested axis.
template <>
void Impl<op::OneHot>::operator()()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -68,9 +74,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()() ...@@ -68,9 +74,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()()
.add(builder::Input{op_input(), "I"}.add_dims("D", 0, in_shape.size())) .add(builder::Input{op_input(), "I"}.add_dims("D", 0, in_shape.size()))
.add(builder::Input{static_cast<std::int64_t>(0), "Zero"}) .add(builder::Input{static_cast<std::int64_t>(0), "Zero"})
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::UnaryContraction{"="} .add(
.set(builder::ContractionOutput{"ZS"} builder::UnaryContraction{"="}
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) { .set(
builder::ContractionOutput{"ZS"}
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_shape.size(); ++idx) for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
{ {
if (idx == op().get_one_hot_axis()) if (idx == op().get_one_hot_axis())
...@@ -85,15 +94,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()() ...@@ -85,15 +94,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()()
}) })
.add_indices("d", 0, out_shape.size())) .add_indices("d", 0, out_shape.size()))
.set(builder::ContractionInput{"Zero"})) .set(builder::ContractionInput{"Zero"}))
.add(builder::Elementwise{"Idx", .add(builder::Elementwise{
"index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"}) "Idx", "index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"})
.add(builder::Elementwise{"IS", "reshape(I, " + in_reshape.str() + ")"}) .add(builder::Elementwise{"IS", "reshape(I, " + in_reshape.str() + ")"})
.add(builder::Elementwise{"OV", "IS == Idx ? 1 : 0"}) .add(builder::Elementwise{"OV", "IS == Idx ? 1 : 0"})
.add(builder::Elementwise{"O", tile_converter("OV", op().get_element_type())}) .add(builder::Elementwise{"O",
tile_converter("OV", op().get_element_type())})
.finalize()); .finalize());
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::Registration register_one_hot; Impl<op::OneHot>::Registration register_one_hot;
}
}
}
} }
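The kernel's core trick is comparing the reshaped input against per-position indices generated by index() along the one-hot axis. The same comparison in plain C++, for a hypothetical label vector {1, 0, 2} with three classes:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<std::int64_t> I{1, 0, 2}; // class labels
        std::size_t classes = 3;

        for (std::size_t i = 0; i < I.size(); ++i)
        {
            for (std::size_t c = 0; c < classes; ++c)
            {
                // Mirrors the Tile expression "IS == Idx ? 1 : 0".
                std::cout << (I[i] == static_cast<std::int64_t>(c) ? 1 : 0) << " ";
            }
            std::cout << "\n";
        }
        // Prints:
        // 0 1 0
        // 1 0 0
        // 0 0 1
    }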
...@@ -20,10 +20,16 @@ ...@@ -20,10 +20,16 @@
#include "ngraph/runtime/plaidml/plaidml_convpool_formatter.hpp" #include "ngraph/runtime/plaidml/plaidml_convpool_formatter.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// AvgPool implements a batch average pooling operation. namespace ngraph
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPool>::operator()()
{ {
namespace runtime
{
namespace plaidml
{
// AvgPool implements a batch average pooling operation.
template <>
void Impl<op::AvgPool>::operator()()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -92,12 +98,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPool>::operator()() ...@@ -92,12 +98,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPool>::operator()()
f.add(cpf.PoolContraction()).add(builder::Elementwise{"O", "S / Count"}); f.add(cpf.PoolContraction()).add(builder::Elementwise{"O", "S / Count"});
set_output(f.finalize()); set_output(f.finalize());
} }
// MaxPool implements a batch max pooling operation. // MaxPool implements a batch max pooling operation.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPool>::operator()() void Impl<op::MaxPool>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -156,11 +162,11 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPool>::operator()() ...@@ -156,11 +162,11 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPool>::operator()()
.add(cpf.O_out_header()) .add(cpf.O_out_header())
.add(cpf.PoolContraction()) .add(cpf.PoolContraction())
.finalize()); .finalize());
} }
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()() void Impl<op::AvgPoolBackprop>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -174,7 +180,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()() ...@@ -174,7 +180,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()()
if (include_padding) if (include_padding)
{ {
throw std::runtime_error("Include padding in average not yet implemented in PlaidML"); throw std::runtime_error(
"Include padding in average not yet implemented in PlaidML");
} }
ngraph::CoordinateDiff pad_above; ngraph::CoordinateDiff pad_above;
...@@ -229,18 +236,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()() ...@@ -229,18 +236,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()()
{ {
std::ostringstream s; std::ostringstream s;
s << "XI" << i - 2; s << "XI" << i - 2;
ret.add(builder::Input{static_cast<std::int64_t>(forward_arg_shape[i]), s.str()}); ret.add(
builder::Input{static_cast<std::int64_t>(forward_arg_shape[i]), s.str()});
} }
set_output(ret.add(cpf.Broadcast_Ones()) set_output(ret.add(cpf.Broadcast_Ones())
.add(cpf.Count()) .add(cpf.Count())
.add(builder::Elementwise{"S", "DO / Count"}) .add(builder::Elementwise{"S", "DO / Count"})
.add(cpf.PoolContraction()) .add(cpf.PoolContraction())
.finalize()); .finalize());
} }
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPoolBackprop>::operator()() void Impl<op::MaxPoolBackprop>::operator()()
{ {
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
...@@ -299,14 +307,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPoolBackprop>::operator()() ...@@ -299,14 +307,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPoolBackprop>::operator()()
.add(cpf.PoolContraction()) .add(cpf.PoolContraction())
.add(cpf.PoolDerivContraction()) .add(cpf.PoolDerivContraction())
.finalize()); .finalize());
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::AvgPool>::Registration register_avg_pool; Impl<op::AvgPool>::Registration register_avg_pool;
ngraph::runtime::plaidml::Impl<ngraph::op::MaxPool>::Registration register_max_pool; Impl<op::MaxPool>::Registration register_max_pool;
ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::Registration Impl<op::AvgPoolBackprop>::Registration register_avg_pool_backprop;
register_avg_pool_backprop; Impl<op::MaxPoolBackprop>::Registration register_max_pool_backprop;
ngraph::runtime::plaidml::Impl<ngraph::op::MaxPoolBackprop>::Registration }
register_max_pool_backprop; }
}
} }
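AvgPool above is emitted as a pooled-sum contraction "S" divided by a "Count" tensor of per-output window sizes. A 1-D plain-C++ analogue of that split, with an assumed window = stride = 2 and no padding:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<double> in{1, 2, 3, 4, 5, 6};
        const std::size_t window = 2;

        for (std::size_t o = 0; o + window <= in.size(); o += window)
        {
            double s = 0; // the "S" contraction: sum over the window
            for (std::size_t k = 0; k < window; ++k)
                s += in[o + k];
            double count = window; // "Count": elements contributing at this output
            std::cout << s / count << " ";
        }
        std::cout << "\n"; // prints 1.5 3.5 5.5
    }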
...@@ -42,13 +42,10 @@ namespace ngraph ...@@ -42,13 +42,10 @@ namespace ngraph
void build_reduction(const char* agg_op); void build_reduction(const char* agg_op);
}; };
}
}
}
template <typename O> template <typename O>
void ngraph::runtime::plaidml::ReductionImpl<O>::build_reduction(const char* agg_op) void ReductionImpl<O>::build_reduction(const char* agg_op)
{ {
this->check_inputs(1); this->check_inputs(1);
this->check_outputs(1); this->check_outputs(1);
...@@ -68,81 +65,86 @@ void ngraph::runtime::plaidml::ReductionImpl<O>::build_reduction(const char* agg ...@@ -68,81 +65,86 @@ void ngraph::runtime::plaidml::ReductionImpl<O>::build_reduction(const char* agg
this->start_tile_function() this->start_tile_function()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Input{this->op_input(0), "I"}.add_dims("D", 1, in_dim_limit + 1)) .add(builder::Input{this->op_input(0), "I"}.add_dims(
.add(builder::UnaryContraction{agg_op} "D", 1, in_dim_limit + 1))
.set(builder::ContractionOutput{"O"} .add(
.add_indices( builder::UnaryContraction{agg_op}
[&](std::back_insert_iterator<std::list<std::string>> out) { .set(
builder::ContractionOutput{"O"}
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_idxs.size(); ++idx) for (std::size_t idx = 0; idx < out_idxs.size(); ++idx)
{ {
out = "d" + std::to_string(out_idxs[idx] + 1); out = "d" + std::to_string(out_idxs[idx] + 1);
} }
}) })
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) { .add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_idxs.size(); ++idx) for (std::size_t idx = 0; idx < out_idxs.size(); ++idx)
{ {
out = "D" + std::to_string(out_idxs[idx] + 1); out = "D" + std::to_string(out_idxs[idx] + 1);
} }
})) }))
.set(builder::ContractionInput{"I"}.add_indices("d", 1, in_dim_limit + 1))) .set(builder::ContractionInput{"I"}.add_indices(
"d", 1, in_dim_limit + 1)))
.finalize()); .finalize());
} }
template <> template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Max> struct ParentImpl<op::Max>
{ {
using Type = ngraph::runtime::plaidml::ReductionImpl<ngraph::op::Max>; using Type = ReductionImpl<op::Max>;
}; };
template <> template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Min> struct ParentImpl<op::Min>
{ {
using Type = ngraph::runtime::plaidml::ReductionImpl<ngraph::op::Min>; using Type = ReductionImpl<op::Min>;
}; };
template <> template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Product> struct ParentImpl<op::Product>
{ {
using Type = ngraph::runtime::plaidml::ReductionImpl<ngraph::op::Product>; using Type = ReductionImpl<op::Product>;
}; };
template <> template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Reduce> struct ParentImpl<op::Reduce>
{ {
using Type = ngraph::runtime::plaidml::ReductionImpl<ngraph::op::Reduce>; using Type = ReductionImpl<op::Reduce>;
}; };
template <> template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Sum> struct ParentImpl<op::Sum>
{ {
using Type = ngraph::runtime::plaidml::ReductionImpl<ngraph::op::Sum>; using Type = ReductionImpl<op::Sum>;
}; };
// Max reduces a tensor, taking the maximum along the specified axes. // Max reduces a tensor, taking the maximum along the specified axes.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Max>::operator()() void Impl<op::Max>::operator()()
{ {
build_reduction(">"); build_reduction(">");
} }
// Min reduces a tensor, taking the minimum along the specified axes. // Min reduces a tensor, taking the minimum along the specified axes.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Min>::operator()() void Impl<op::Min>::operator()()
{ {
build_reduction("<"); build_reduction("<");
} }
// Product reduces a tensor, taking the product along the specified axes. // Product reduces a tensor, taking the product along the specified axes.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Product>::operator()() void Impl<op::Product>::operator()()
{ {
build_reduction("*"); build_reduction("*");
} }
// Reduce reduces a tensor with an arbitrary user-supplied reduction operation. // Reduce reduces a tensor with an arbitrary user-supplied reduction operation.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()() void Impl<op::Reduce>::operator()()
{ {
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
...@@ -183,10 +185,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()() ...@@ -183,10 +185,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
start_tile_function() start_tile_function()
.add(builder::Input{op_input(1), "I"}) .add(builder::Input{op_input(1), "I"})
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::UnaryContraction{"="} .add(
.set(builder::ContractionOutput{"O"} builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"O"}
.add_indices("d", 0, agg_dim_limit) .add_indices("d", 0, agg_dim_limit)
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) { .add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < agg_dim_limit; ++idx) for (auto idx = 0; idx < agg_dim_limit; ++idx)
{ {
out = "1"; out = "1";
...@@ -205,9 +210,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()() ...@@ -205,9 +210,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::UnaryContraction{"="} .add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"} .set(builder::ContractionOutput{"O"}
.add_indices( .add_indices([&](
[&](std::back_insert_iterator<std::list<std::string>> out) { std::back_insert_iterator<std::list<std::string>>
for (std::size_t idx = 0; idx < input_shape.size(); ++idx) out) {
for (std::size_t idx = 0;
idx < input_shape.size();
++idx)
{ {
if (!op().get_reduction_axes().count(idx)) if (!op().get_reduction_axes().count(idx))
{ {
...@@ -215,9 +223,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()() ...@@ -215,9 +223,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
} }
} }
}) })
.add_dims( .add_dims([&](
[&](std::back_insert_iterator<std::list<std::string>> out) { std::back_insert_iterator<std::list<std::string>>
for (std::size_t idx = 0; idx < input_shape.size(); ++idx) out) {
for (std::size_t idx = 0;
idx < input_shape.size();
++idx)
{ {
if (!op().get_reduction_axes().count(idx)) if (!op().get_reduction_axes().count(idx))
{ {
...@@ -225,8 +236,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()() ...@@ -225,8 +236,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
} }
} }
})) }))
.set(builder::ContractionInput{"I"}.add_indices( .set(builder::ContractionInput{"I"}.add_indices([&](
[&](std::back_insert_iterator<std::list<std::string>> out) { std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < input_shape.size(); ++idx) for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
{ {
std::size_t cidx = 0; std::size_t cidx = 0;
...@@ -244,20 +255,23 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()() ...@@ -244,20 +255,23 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
} }
set_output(result); set_output(result);
} }
// Sum reduces a tensor, summing the specified axes. // Sum reduces a tensor, summing the specified axes.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Sum>::operator()() void Impl<op::Sum>::operator()()
{ {
build_reduction("+"); build_reduction("+");
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::Max>::Registration register_max; Impl<op::Max>::Registration register_max;
ngraph::runtime::plaidml::Impl<ngraph::op::Min>::Registration register_min; Impl<op::Min>::Registration register_min;
ngraph::runtime::plaidml::Impl<ngraph::op::Product>::Registration register_product; Impl<op::Product>::Registration register_product;
ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::Registration register_reduce; Impl<op::Reduce>::Registration register_reduce;
ngraph::runtime::plaidml::Impl<ngraph::op::Sum>::Registration register_sum; Impl<op::Sum>::Registration register_sum;
}
}
}
} }
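build_reduction only varies the aggregation operator: "+" for Sum, ">" for Max, "<" for Min, "*" for Product. The contraction it emits keeps the non-reduced indices and aggregates the rest; the plain-C++ analogue of summing away axis 1 of a 2x3 tensor:

    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<std::vector<double>> I{{1, 2, 3}, {4, 5, 6}};

        for (const auto& row : I)
        {
            double o = 0;
            for (double v : row)
                o += v; // "+" aggregation; Max/Min/Product swap in >, <, *
            std::cout << o << " ";
        }
        std::cout << "\n"; // prints 6 15
    }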
...@@ -19,10 +19,16 @@ ...@@ -19,10 +19,16 @@
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// ReplaceSlice replaces part of a tensor with another tensor. namespace ngraph
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
{ {
namespace runtime
{
namespace plaidml
{
// ReplaceSlice replaces part of a tensor with another tensor.
template <>
void Impl<op::ReplaceSlice>::operator()()
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
...@@ -43,11 +49,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()() ...@@ -43,11 +49,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
.add(builder::Input{op_input(0), "L"}.add_dims("D", 0, shape.size())) .add(builder::Input{op_input(0), "L"}.add_dims("D", 0, shape.size()))
.add(builder::Input{op_input(1), "S"}.add_dims("SD", 0, shape.size())) .add(builder::Input{op_input(1), "S"}.add_dims("SD", 0, shape.size()))
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::UnaryContraction{"="} .add(
.set(builder::ContractionOutput{"O"} builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"O"}
.add_dims("D", 0, shape.size()) .add_dims("D", 0, shape.size())
.add_indices( .add_indices([&](
[&](std::back_insert_iterator<std::list<std::string>> out) { std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx) for (std::size_t idx = 0; idx < shape.size(); ++idx)
{ {
auto stride = op().get_strides()[idx]; auto stride = op().get_strides()[idx];
...@@ -73,8 +81,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()() ...@@ -73,8 +81,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
out = didx.str(); out = didx.str();
} }
})) }))
.set(builder::ContractionInput{"S"}.add_indices("d", 0, shape.size())) .set(builder::ContractionInput{"S"}.add_indices(
.add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) { "d", 0, shape.size()))
.add_constraints(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx) for (std::size_t idx = 0; idx < shape.size(); ++idx)
{ {
out = "d" + std::to_string(idx) + " < " + out = "d" + std::to_string(idx) + " < " +
...@@ -84,9 +94,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()() ...@@ -84,9 +94,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
}) })
.set_default("L")) .set_default("L"))
.finalize()); .finalize());
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::Registration register_replace_slice; Impl<op::ReplaceSlice>::Registration register_replace_slice;
}
}
}
} }
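ReplaceSlice writes S into a bounded window of L and keeps L everywhere else, which is what the contraction's constraints plus set_default("L") express. A 1-D plain-C++ equivalent with an assumed lower bound of 1 and stride 1:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<int> L{0, 0, 0, 0, 0};
        std::vector<int> S{7, 8, 9};
        const std::size_t lower = 1;

        std::vector<int> O = L; // the default: L everywhere
        for (std::size_t i = 0; i < S.size(); ++i)
            O[lower + i] = S[i]; // overwrite inside the constrained window

        for (int v : O)
            std::cout << v << " "; // prints 0 7 8 9 0
        std::cout << "\n";
    }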
...@@ -19,10 +19,16 @@ ...@@ -19,10 +19,16 @@
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Reverse reverses the selected axes within a tensor. namespace ngraph
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()()
{ {
namespace runtime
{
namespace plaidml
{
// Reverse reverses the selected axes within a tensor.
template <>
void Impl<op::Reverse>::operator()()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -35,8 +41,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()() ...@@ -35,8 +41,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()()
.set(builder::ContractionOutput{"O"} .set(builder::ContractionOutput{"O"}
.add_indices("d", 0, shape.size()) .add_indices("d", 0, shape.size())
.add_dims("D", 0, shape.size())) .add_dims("D", 0, shape.size()))
.set(builder::ContractionInput{"I"}.add_indices( .set(builder::ContractionInput{"I"}.add_indices([&](
[&](std::back_insert_iterator<std::list<std::string>> out) { std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx) for (std::size_t idx = 0; idx < shape.size(); ++idx)
{ {
auto sidx = std::to_string(idx); auto sidx = std::to_string(idx);
...@@ -51,9 +57,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()() ...@@ -51,9 +57,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()()
} }
}))) })))
.finalize()); .finalize());
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::Registration register_reverse; Impl<op::Reverse>::Registration register_reverse;
}
}
}
} }
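For a reversed axis, output index d presumably reads input index D - d - 1 (the index expression itself is elided in the hunk above, but this is the standard reverse mapping). A 1-D illustration:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<int> in{1, 2, 3, 4};
        const std::size_t D = in.size();

        for (std::size_t d = 0; d < D; ++d)
            std::cout << in[D - d - 1] << " "; // prints 4 3 2 1
        std::cout << "\n";
    }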
...@@ -18,10 +18,16 @@ ...@@ -18,10 +18,16 @@
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Slice takes a sub-slice of a tensor. namespace ngraph
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()()
{ {
namespace runtime
{
namespace plaidml
{
// Slice takes a sub-slice of a tensor.
template <>
void Impl<op::Slice>::operator()()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
NGRAPH_DEBUG << "Slice: low: " << op().get_lower_bounds(); NGRAPH_DEBUG << "Slice: low: " << op().get_lower_bounds();
...@@ -33,17 +39,21 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()() ...@@ -33,17 +39,21 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()()
start_tile_function() start_tile_function()
.add(builder::Input{op_input(), "I"}.add_dims("ID", 0, dim_limit)) .add(builder::Input{op_input(), "I"}.add_dims("ID", 0, dim_limit))
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::UnaryContraction{"="} .add(
.set(builder::ContractionOutput{"O"} builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"O"}
.add_indices("od", 0, dim_limit) .add_indices("od", 0, dim_limit)
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) { .add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx) for (std::size_t idx = 0; idx < dim_limit; ++idx)
{ {
std::ostringstream s; std::ostringstream s;
std::size_t stride = op().get_strides()[idx]; std::size_t stride = op().get_strides()[idx];
std::ptrdiff_t trim_count = std::ptrdiff_t trim_count =
op().get_lower_bounds()[idx] + op().get_lower_bounds()[idx] +
(shape[idx] - op().get_upper_bounds()[idx]) + 1 - stride; (shape[idx] - op().get_upper_bounds()[idx]) +
1 - stride;
if ((stride != 1) && trim_count) if ((stride != 1) && trim_count)
{ {
s << "("; s << "(";
...@@ -96,9 +106,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()() ...@@ -96,9 +106,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()()
} }
}))) })))
.finalize()); .finalize());
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::Registration register_slice; Impl<op::Slice>::Registration register_slice;
}
}
}
} }
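Stripped of the stride/trim bookkeeping above, Slice maps output index od to input index lower + od * stride, keeping only indices below the upper bound. An illustrative 1-D case (assumed bounds; not from the diff):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<int> in{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        const std::size_t lower = 1, upper = 7, stride = 2;

        for (std::size_t od = 0; lower + od * stride < upper; ++od)
            std::cout << in[lower + od * stride] << " "; // prints 1 3 5
        std::cout << "\n";
    }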
...@@ -19,10 +19,16 @@ ...@@ -19,10 +19,16 @@
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Softmax implements a standard ML softmax operation. namespace ngraph
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
{ {
namespace runtime
{
namespace plaidml
{
// Softmax implements a standard ML softmax operation.
template <>
void Impl<op::Softmax>::operator()()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -30,7 +36,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()() ...@@ -30,7 +36,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
auto dim_limit = shape.size(); auto dim_limit = shape.size();
auto f = start_tile_function(); auto f = start_tile_function();
f.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, dim_limit)).add(builder::Output{"O"}); f.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, dim_limit))
.add(builder::Output{"O"});
bool reorder_needed = false; bool reorder_needed = false;
bool saw_element = false; bool saw_element = false;
...@@ -71,7 +78,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()() ...@@ -71,7 +78,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
{ {
f.add(builder::UnaryContraction{"="} f.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"RI"} .set(builder::ContractionOutput{"RI"}
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) { .add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx : group_idxs) for (auto idx : group_idxs)
{ {
out = "D" + std::to_string(idx); out = "D" + std::to_string(idx);
...@@ -81,7 +89,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()() ...@@ -81,7 +89,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
out = "D" + std::to_string(idx); out = "D" + std::to_string(idx);
} }
}) })
.add_indices([&](std::back_insert_iterator<std::list<std::string>> out) { .add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx : group_idxs) for (auto idx : group_idxs)
{ {
out = "d" + std::to_string(idx); out = "d" + std::to_string(idx);
...@@ -117,7 +126,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()() ...@@ -117,7 +126,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
{ {
// Take the softmax. // Take the softmax.
std::ostringstream softmax; std::ostringstream softmax;
softmax << "builtin_softmax(" << input << ", " << groups << ", " << elements << ")"; softmax << "builtin_softmax(" << input << ", " << groups << ", " << elements
<< ")";
f.add(builder::Elementwise{output, softmax.str()}); f.add(builder::Elementwise{output, softmax.str()});
} }
...@@ -159,9 +169,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()() ...@@ -159,9 +169,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
} }
set_output(f.finalize()); set_output(f.finalize());
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::Registration register_softmax; Impl<op::Softmax>::Registration register_softmax;
}
}
}
} }
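builtin_softmax(I, groups, elements) presumably treats its input as groups rows of `elements` values and normalizes each row, which is why the kernel first reorders the group axes to the front. One row computed by hand in plain C++:

    #include <cmath>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<double> group{1.0, 2.0, 3.0};

        double denom = 0;
        for (double v : group)
            denom += std::exp(v);

        for (double v : group)
            std::cout << std::exp(v) / denom << " "; // ~0.090 0.245 0.665
        std::cout << "\n";
    }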
...@@ -29,10 +29,16 @@ ...@@ -29,10 +29,16 @@
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// acos performs a simple elementwise arccos function. namespace ngraph
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Acos>::operator()()
{ {
namespace runtime
{
namespace plaidml
{
// acos performs a simple elementwise arccos function.
template <>
void Impl<op::Acos>::operator()()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -40,12 +46,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Acos>::operator()() ...@@ -40,12 +46,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Acos>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "acos(I)"}) .add(builder::Elementwise{"O", "acos(I)"})
.finalize()); .finalize());
} }
// asin performs a simple elementwise arcsin function. // asin performs a simple elementwise arcsin function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Asin>::operator()() void Impl<op::Asin>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -53,12 +59,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Asin>::operator()() ...@@ -53,12 +59,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Asin>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "asin(I)"}) .add(builder::Elementwise{"O", "asin(I)"})
.finalize()); .finalize());
} }
// atan performs a simple elementwise arctan function. // atan performs a simple elementwise arctan function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Atan>::operator()() void Impl<op::Atan>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -66,12 +72,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Atan>::operator()() ...@@ -66,12 +72,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Atan>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "atan(I)"}) .add(builder::Elementwise{"O", "atan(I)"})
.finalize()); .finalize());
} }
// cos performs a simple elementwise cos function. // cos performs a simple elementwise cos function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Cos>::operator()() void Impl<op::Cos>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -79,12 +85,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Cos>::operator()() ...@@ -79,12 +85,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Cos>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cos(I)"}) .add(builder::Elementwise{"O", "cos(I)"})
.finalize()); .finalize());
} }
// cosh performs a simple elementwise hyperbolic cos function. // cosh performs a simple elementwise hyperbolic cos function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Cosh>::operator()() void Impl<op::Cosh>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -92,12 +98,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Cosh>::operator()() ...@@ -92,12 +98,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Cosh>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cosh(I)"}) .add(builder::Elementwise{"O", "cosh(I)"})
.finalize()); .finalize());
} }
// exp performs a simple elementwise natural exponential function. // exp performs a simple elementwise natural exponential function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Exp>::operator()() void Impl<op::Exp>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -105,12 +111,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Exp>::operator()() ...@@ -105,12 +111,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Exp>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "exp(I)"}) .add(builder::Elementwise{"O", "exp(I)"})
.finalize()); .finalize());
} }
// log performs a simple elementwise natural logarithm function. // log performs a simple elementwise natural logarithm function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Log>::operator()() void Impl<op::Log>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -118,12 +124,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Log>::operator()() ...@@ -118,12 +124,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Log>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "log(I)"}) .add(builder::Elementwise{"O", "log(I)"})
.finalize()); .finalize());
} }
// power performs a simple elementwise power function. // power performs a simple elementwise power function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Power>::operator()() void Impl<op::Power>::operator()()
{ {
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -132,12 +138,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Power>::operator()() ...@@ -132,12 +138,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Power>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "pow(I, E)"}) .add(builder::Elementwise{"O", "pow(I, E)"})
.finalize()); .finalize());
} }
// sin performs a simple elementwise sin function. // sin performs a simple elementwise sin function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Sin>::operator()() void Impl<op::Sin>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -145,12 +151,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sin>::operator()() ...@@ -145,12 +151,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sin>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sin(I)"}) .add(builder::Elementwise{"O", "sin(I)"})
.finalize()); .finalize());
} }
// sinh performs a simple elementwise hyperbolic sin function. // sinh performs a simple elementwise hyperbolic sin function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Sinh>::operator()() void Impl<op::Sinh>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -158,12 +164,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sinh>::operator()() ...@@ -158,12 +164,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sinh>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sinh(I)"}) .add(builder::Elementwise{"O", "sinh(I)"})
.finalize()); .finalize());
} }
// sqrt performs a simple elementwise square root function. // sqrt performs a simple elementwise square root function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Sqrt>::operator()() void Impl<op::Sqrt>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -171,12 +177,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sqrt>::operator()() ...@@ -171,12 +177,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sqrt>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sqrt(I)"}) .add(builder::Elementwise{"O", "sqrt(I)"})
.finalize()); .finalize());
} }
// tan performs a simple elementwise tangent function. // tan performs a simple elementwise tangent function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Tan>::operator()() void Impl<op::Tan>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -184,12 +190,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Tan>::operator()() ...@@ -184,12 +190,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Tan>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "tan(I)"}) .add(builder::Elementwise{"O", "tan(I)"})
.finalize()); .finalize());
} }
// tanh performs a simple elementwise hyperbolic tangent function. // tanh performs a simple elementwise hyperbolic tangent function.
template <> template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Tanh>::operator()() void Impl<op::Tanh>::operator()()
{ {
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -197,21 +203,24 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Tanh>::operator()() ...@@ -197,21 +203,24 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Tanh>::operator()()
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "tanh(I)"}) .add(builder::Elementwise{"O", "tanh(I)"})
.finalize()); .finalize());
} }
namespace namespace
{ {
ngraph::runtime::plaidml::Impl<ngraph::op::Acos>::Registration register_acos; Impl<op::Acos>::Registration register_acos;
ngraph::runtime::plaidml::Impl<ngraph::op::Asin>::Registration register_asin; Impl<op::Asin>::Registration register_asin;
ngraph::runtime::plaidml::Impl<ngraph::op::Atan>::Registration register_atan; Impl<op::Atan>::Registration register_atan;
ngraph::runtime::plaidml::Impl<ngraph::op::Cos>::Registration register_cos; Impl<op::Cos>::Registration register_cos;
ngraph::runtime::plaidml::Impl<ngraph::op::Cosh>::Registration register_cosh; Impl<op::Cosh>::Registration register_cosh;
ngraph::runtime::plaidml::Impl<ngraph::op::Exp>::Registration register_exp; Impl<op::Exp>::Registration register_exp;
ngraph::runtime::plaidml::Impl<ngraph::op::Log>::Registration register_log; Impl<op::Log>::Registration register_log;
ngraph::runtime::plaidml::Impl<ngraph::op::Power>::Registration register_power; Impl<op::Power>::Registration register_power;
ngraph::runtime::plaidml::Impl<ngraph::op::Sin>::Registration register_sin; Impl<op::Sin>::Registration register_sin;
ngraph::runtime::plaidml::Impl<ngraph::op::Sinh>::Registration register_sinh; Impl<op::Sinh>::Registration register_sinh;
ngraph::runtime::plaidml::Impl<ngraph::op::Sqrt>::Registration register_sqrt; Impl<op::Sqrt>::Registration register_sqrt;
ngraph::runtime::plaidml::Impl<ngraph::op::Tan>::Registration register_tan; Impl<op::Tan>::Registration register_tan;
ngraph::runtime::plaidml::Impl<ngraph::op::Tanh>::Registration register_tanh; Impl<op::Tanh>::Registration register_tanh;
}
}
}
} }
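Each file in this commit ends with anonymous-namespace Registration objects (register_tanh and friends) whose constructors run during static initialization. The real Registration type lives in plaidml_impl.hpp and is not shown in this diff; a minimal sketch of how such a self-registering mechanism typically works, under that assumption:

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    // Assumed shape of the registry: op name -> kernel callable.
    std::map<std::string, std::function<void()>>& registry()
    {
        static std::map<std::string, std::function<void()>> r;
        return r;
    }

    struct Registration
    {
        Registration(std::string name, std::function<void()> fn)
        {
            registry().emplace(std::move(name), std::move(fn));
        }
    };

    namespace
    {
        // Constructed during static initialization, like register_tanh above.
        Registration register_demo{"Demo", [] { std::cout << "Demo kernel\n"; }};
    }

    int main()
    {
        registry().at("Demo")(); // look up and invoke the registered kernel
    }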
...@@ -38,6 +38,8 @@ topk_2d_max_one # No plans to implement TopK ...@@ -38,6 +38,8 @@ topk_2d_max_one # No plans to implement TopK
topk_2d_min_all # No plans to implement TopK topk_2d_min_all # No plans to implement TopK
topk_2d_min_partial # No plans to implement TopK topk_2d_min_partial # No plans to implement TopK
topk_2d_min_one # No plans to implement TopK topk_2d_min_one # No plans to implement TopK
topk_int64 # No plans to implement TopK
topk_5d_max_partial # No plans to implement TopK
# Tests that PlaidML might be able to run at some point. # Tests that PlaidML might be able to run at some point.
backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3 backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3
...@@ -84,3 +86,5 @@ sum_3d_eliminate_zero_dim # Empty dims apparently should produce shape ...@@ -84,3 +86,5 @@ sum_3d_eliminate_zero_dim # Empty dims apparently should produce shape
dot_0_0 # Empty dims apparently should produce shaped 0s dot_0_0 # Empty dims apparently should produce shaped 0s
dot_matrix_2x0_0x2 # Empty dims apparently should produce shaped 0s dot_matrix_2x0_0x2 # Empty dims apparently should produce shaped 0s
dot_2x0_0 # Empty dims apparently should produce shaped 0s dot_2x0_0 # Empty dims apparently should produce shaped 0s
numeric_float_nan
numeric_double_nan