Commit 61df6725
Authored Oct 31, 2018 by Rob Earhart; committed by Robert Kimball, Oct 31, 2018
[PlaidML] Specialize within namespaces (for Linux) (#1948)
Parent: 5698fa75
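Why this matters: before C++17, [temp.expl.spec] required an explicit specialization to be declared in the namespace of which its template is a member. GCC enforces that rule, while MSVC has historically accepted specializations written at global scope with qualified names, which is why the old form built on Windows but broke the Linux build. A minimal sketch of the distinction (illustrative code only, not from this repository):

    namespace lib
    {
        template <typename T>
        struct Impl
        {
            void operator()();
        };
    }

    // Rejected by GCC before C++17: specialization at global scope using a
    // qualified name (the pattern this commit removes).
    // template <>
    // void lib::Impl<int>::operator()() {}

    // Portable form: reopen the namespace and specialize inside it (the
    // pattern this commit introduces).
    namespace lib
    {
        template <>
        void Impl<int>::operator()()
        {
        }
    }

    int main()
    {
        lib::Impl<int>{}();  // calls the explicit specialization
    }

Every hunk below applies exactly this transformation: the fully qualified specializations are re-declared inside namespace ngraph { namespace runtime { namespace plaidml { ... } } }, and the now-redundant ngraph:: and ngraph::runtime::plaidml:: qualifiers are dropped. The function bodies themselves are unchanged.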
Showing 22 changed files with 2816 additions and 2566 deletions (+2816 / -2566):

    src/ngraph/runtime/plaidml/plaidml_ops_arithmetic.cpp           +189  -180
    src/ngraph/runtime/plaidml/plaidml_ops_batch_norm.cpp           +284  -258
    src/ngraph/runtime/plaidml/plaidml_ops_comparison.cpp           +128  -119
    src/ngraph/runtime/plaidml/plaidml_ops_concat.cpp                +91   -77
    src/ngraph/runtime/plaidml/plaidml_ops_convert.cpp               +25   -15
    src/ngraph/runtime/plaidml/plaidml_ops_convolution.cpp          +207  -209
    src/ngraph/runtime/plaidml/plaidml_ops_dot.cpp                   +51   -41
    src/ngraph/runtime/plaidml/plaidml_ops_function.cpp              +32   -22
    src/ngraph/runtime/plaidml/plaidml_ops_general.cpp              +439  -409
    src/ngraph/runtime/plaidml/plaidml_ops_index_reduction.cpp      +105   -98
    src/ngraph/runtime/plaidml/plaidml_ops_io.cpp                    +37   -28
    ...graph/runtime/plaidml/plaidml_ops_local_response_norm.cpp     +46   -34
    src/ngraph/runtime/plaidml/plaidml_ops_logical.cpp               +54   -45
    src/ngraph/runtime/plaidml/plaidml_ops_one_hot.cpp               +83   -70
    src/ngraph/runtime/plaidml/plaidml_ops_pool.cpp                 +293  -284
    src/ngraph/runtime/plaidml/plaidml_ops_reduce.cpp               +209  -195
    src/ngraph/runtime/plaidml/plaidml_ops_replace_slice.cpp         +78   -65
    src/ngraph/runtime/plaidml/plaidml_ops_reverse.cpp               +42   -33
    src/ngraph/runtime/plaidml/plaidml_ops_slice.cpp                 +94   -81
    src/ngraph/runtime/plaidml/plaidml_ops_softmax.cpp              +145  -132
    src/ngraph/runtime/plaidml/plaidml_ops_transcendental.cpp       +180  -171
    src/ngraph/runtime/plaidml/unit_test.manifest                     +4    -0
src/ngraph/runtime/plaidml/plaidml_ops_arithmetic.cpp
@@ -28,198 +28,207 @@

    #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
    #include "ngraph/runtime/plaidml/plaidml_translate.hpp"

Removed: the thirteen specializations (Abs, Add, Ceiling, Divide, Floor, Multiply, Negative, Relu, ReluBackprop, Sigmoid, SigmoidBackprop, Sign, Subtract) and their Registration objects, all written at global scope with fully qualified names; the bodies are identical to the added version below. Representative removed form:

    // Abs performs a simple elementwise absolute value.
    template <>
    void ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::operator()()
    {
        check_inputs(1);
        check_outputs(1);
        set_output(start_tile_function()
                       .add(builder::Input{op_input(0), "I"})
                       .add(builder::Output{"O"})
                       .add(builder::Elementwise{"O", "abs(I)"})
                       .finalize());
    }

    namespace
    {
        ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::Registration register_abs;
        // ...one Registration per op, likewise fully qualified...
    }

Added: the same specializations declared inside the enclosing namespaces:

    namespace ngraph
    {
        namespace runtime
        {
            namespace plaidml
            {
                // Abs performs a simple elementwise absolute value.
                template <>
                void Impl<op::Abs>::operator()()
                {
                    check_inputs(1);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "I"})
                                   .add(builder::Output{"O"})
                                   .add(builder::Elementwise{"O", "abs(I)"})
                                   .finalize());
                }

                // Add performs a simple elementwise addition.
                template <>
                void Impl<op::Add>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "A"})
                                   .add(builder::Input{op_input(1), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "A + B"})
                                   .finalize());
                }

                // Ceiling performs a simple elementwise ceiling.
                template <>
                void Impl<op::Ceiling>::operator()()
                {
                    check_inputs(1);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "I"})
                                   .add(builder::Output{"O"})
                                   .add(builder::Elementwise{"O", "ceil(I)"})
                                   .finalize());
                }

                // Divide performs a simple elementwise division.
                template <>
                void Impl<op::Divide>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "A"})
                                   .add(builder::Input{op_input(1), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "A / B"})
                                   .finalize());
                }

                // Floor performs a simple elementwise floor.
                template <>
                void Impl<op::Floor>::operator()()
                {
                    check_inputs(1);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "I"})
                                   .add(builder::Output{"O"})
                                   .add(builder::Elementwise{"O", "floor(I)"})
                                   .finalize());
                }

                // Multiply performs a simple elementwise multiplication.
                template <>
                void Impl<op::Multiply>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "A"})
                                   .add(builder::Input{op_input(1), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "A * B"})
                                   .finalize());
                }

                // Negative performs a simple elementwise negation.
                template <>
                void Impl<op::Negative>::operator()()
                {
                    check_inputs(1);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(), "I"})
                                   .add(builder::Output{"O"})
                                   .add(builder::Elementwise{"O", "-I"})
                                   .finalize());
                }

                // Relu implements a simple elementwise rectified linear unit.
                template <>
                void Impl<op::Relu>::operator()()
                {
                    check_inputs(1);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(), "I"})
                                   .add(builder::Output{"O"})
                                   .add(builder::Elementwise{"O", "relu(I)"})
                                   .finalize());
                }

                // ReluBackprop computes the derivative of Relu.
                template <>
                void Impl<op::ReluBackprop>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "I"})
                                   .add(builder::Input{op_input(1), "DO"})
                                   .add(builder::Output{"DI"})
                                   .add(builder::Elementwise{"DI", "I > 0 ? DO : 0"})
                                   .finalize());
                }

                // Sigmoid computes a standard ML sigmoid: 1/(1+exp(-X))
                template <>
                void Impl<op::Sigmoid>::operator()()
                {
                    check_inputs(1);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "I"})
                                   .add(builder::Output{"O"})
                                   .add(builder::Elementwise{"O", "1/(1+exp(-I))"})
                                   .finalize());
                }

                // SigmoidBackprop computes the derivative of a standard ML
                // sigmoid: dOutput * sigmoid(X) * (1-sigmoid(X))
                template <>
                void Impl<op::SigmoidBackprop>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "I"})
                                   .add(builder::Input{op_input(1), "DO"})
                                   .add(builder::Output{"DI"})
                                   .add(builder::Elementwise{"O", "1/(1+exp(-I))"})
                                   .add(builder::Elementwise{"DI", "DO * O * (1-O)"})
                                   .finalize());
                }

                // Sign returns the sign of an element.
                template <>
                void Impl<op::Sign>::operator()()
                {
                    check_inputs(1);
                    check_outputs(1);
                    set_output(
                        start_tile_function()
                            .add(builder::Input{op_input(0), "I"})
                            .add(builder::Output{"O"})
                            .add(builder::Elementwise{"S", "(I < 0) ? -1 : ((I > 0) ? 1 : 0)"})
                            .add(builder::Elementwise{
                                "O", tile_converter("S", op().get_element_type())})
                            .finalize());
                }

                // Subtract performs a simple elementwise subtraction.
                template <>
                void Impl<op::Subtract>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "A"})
                                   .add(builder::Input{op_input(1), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "A - B"})
                                   .finalize());
                }

                namespace
                {
                    Impl<op::Abs>::Registration register_abs;
                    Impl<op::Add>::Registration register_add;
                    Impl<op::Ceiling>::Registration register_ceiling;
                    Impl<op::Divide>::Registration register_divide;
                    Impl<op::Floor>::Registration register_floor;
                    Impl<op::Multiply>::Registration register_multiply;
                    Impl<op::Negative>::Registration register_negative;
                    Impl<op::Relu>::Registration register_relu;
                    Impl<op::ReluBackprop>::Registration register_relu_backprop;
                    Impl<op::Sigmoid>::Registration register_sigmoid;
                    Impl<op::SigmoidBackprop>::Registration register_sigmoid_backprop;
                    Impl<op::Sign>::Registration register_sign;
                    Impl<op::Subtract>::Registration register_subtract;
                }
            }
        }
    }
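The Impl<op::X>::Registration objects in the anonymous namespace above are a static-registration idiom: each file-scope object is constructed during static initialization, and its constructor enters a handler for one op type into a global dispatch table. Below is a self-contained sketch of the general idiom under assumed names (handler_table, the std::type_index key); the real mechanism is defined in plaidml_impl.hpp and may differ in detail.

    #include <functional>
    #include <iostream>
    #include <map>
    #include <typeindex>

    // Global dispatch table, keyed by op type. A Meyers singleton avoids
    // static-initialization-order problems across translation units.
    std::map<std::type_index, std::function<void()>>& handler_table()
    {
        static std::map<std::type_index, std::function<void()>> table;
        return table;
    }

    template <typename Op>
    struct Impl
    {
        void operator()();  // explicitly specialized per op, as above

        struct Registration
        {
            // Constructing a Registration inserts a handler for Op.
            Registration() { handler_table()[typeid(Op)] = [] { Impl<Op>{}(); }; }
        };
    };

    struct Abs  // stand-in for ngraph::op::Abs
    {
    };

    // The specialization is declared in the same scope as Impl, which is
    // exactly the constraint this commit satisfies for GCC.
    template <>
    void Impl<Abs>::operator()()
    {
        std::cout << "building Tile code for Abs\n";
    }

    namespace
    {
        Impl<Abs>::Registration register_abs;  // runs before main()
    }

    int main()
    {
        handler_table().at(typeid(Abs))();  // dispatch by op type
    }

The payoff of the idiom is that supporting a new op touches only one translation unit: define the specialization and declare one Registration object.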
src/ngraph/runtime/plaidml/plaidml_ops_batch_norm.cpp
@@ -18,291 +18,317 @@

    #include "ngraph/op/batch_norm.hpp"
    #include "ngraph/runtime/plaidml/plaidml_impl.hpp"

Removed: the same three implementations (BatchNormInference, BatchNormTraining, BatchNormTrainingBackprop) and their Registration objects, written at global scope with fully qualified names (ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference> and so on); the bodies are identical to the added version below.

Added:

    namespace ngraph
    {
        namespace runtime
        {
            namespace plaidml
            {
                // BatchNormInference implements batch normalization for inference, in
                // which the mean and variance to use are supplied.
                template <>
                void Impl<op::BatchNormInference>::operator()()
                {
                    auto& input_shape = op().get_input_shape(2);
                    check_inputs(5);
                    check_outputs(1);

                    auto f = start_tile_function();
                    f.add(builder::Input{op_input(0), "Gamma"}.add_dims({"C"}))
                        .add(builder::Input{op_input(1), "Beta"}.add_dims({"C"}))
                        .add(builder::Input{op_input(2), "Input"}
                                 .add_dims({"B", "C"})
                                 .add_dims("DI", 3, input_shape.size() + 1))
                        .add(builder::Output{"Normalized"})
                        .add(builder::Input{op_input(3), "Mean"}.add_dims({"C"}))
                        .add(builder::Input{op_input(4), "Variance"}.add_dims({"C"}));

                    std::string ones;
                    for (auto idx = 2; idx < input_shape.size(); ++idx)
                    {
                        ones += ", 1";
                    }

                    if (input_shape.size() <= 2)
                    {
                        f.add(builder::Elementwise{"GammaP", "Gamma"})
                            .add(builder::Elementwise{"BetaP", "Beta"});
                    }
                    else
                    {
                        f.add(builder::Elementwise{
                                 "GammaP", std::string{"reshape(Gamma, C"} + ones + ")"})
                            .add(builder::Elementwise{
                                "BetaP", std::string{"reshape(Beta, C"} + ones + ")"});
                    }

                    if (input_shape.size() <= 2)
                    {
                        f.add(builder::Elementwise{"MeanP", "Mean"});
                    }
                    else
                    {
                        f.add(builder::Elementwise{
                            "MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
                    }

                    if (input_shape.size() <= 2)
                    {
                        f.add(builder::Elementwise{"VarianceP", "Variance"});
                    }
                    else
                    {
                        f.add(builder::Elementwise{
                            "VarianceP", std::string{"reshape(Variance, C"} + ones + ")"});
                    }

                    f.add(builder::Elementwise{"Normalized",
                                               "(((Input-MeanP) / sqrt(VarianceP + " +
                                                   std::to_string(op().get_eps_value()) +
                                                   ")) * GammaP) + BetaP"});

                    auto app = f.finalize();

                    set_output(app);
                }

                // BatchNormTraining implements batch normalization for training, in
                // which the mean and variance are to be computed from the supplied
                // input.
                template <>
                void Impl<op::BatchNormTraining>::operator()()
                {
                    auto& input_shape = op().get_input_shape(2);
                    check_inputs(3);
                    check_outputs(3);

                    auto f = start_tile_function();
                    f.add(builder::Input{op_input(0), "Gamma"}.add_dims({"C"}))
                        .add(builder::Input{op_input(1), "Beta"}.add_dims({"C"}))
                        .add(builder::Input{op_input(2), "Input"}
                                 .add_dims({"B", "C"})
                                 .add_dims("DI", 3, input_shape.size() + 1))
                        .add(builder::Output{"Normalized"})
                        .add(builder::Output{"Mean"})
                        .add(builder::Output{"Variance"});

                    std::string ones;
                    for (auto idx = 2; idx < input_shape.size(); ++idx)
                    {
                        ones += ", 1";
                    }

                    if (input_shape.size() <= 2)
                    {
                        f.add(builder::Elementwise{"GammaP", "Gamma"})
                            .add(builder::Elementwise{"BetaP", "Beta"});
                    }
                    else
                    {
                        f.add(builder::Elementwise{
                                 "GammaP", std::string{"reshape(Gamma, C"} + ones + ")"})
                            .add(builder::Elementwise{
                                "BetaP", std::string{"reshape(Beta, C"} + ones + ")"});
                    }

                    if (input_shape.size() <= 2)
                    {
                        f.add(builder::Elementwise{"EltCount", "B"});
                    }
                    else
                    {
                        std::string elts{"B"};
                        for (auto idx = 2; idx < input_shape.size(); ++idx)
                        {
                            elts += " * DI" + std::to_string(idx + 1);
                        }
                        f.add(builder::Elementwise{"EltCount", std::move(elts)});
                    }

                    f.add(builder::UnaryContraction{"+"}
                              .set(builder::ContractionOutput{"SumInput"}
                                       .add_indices({"c"})
                                       .add_dims({"C"}))
                              .set(builder::ContractionInput{"Input"}
                                       .add_indices({"b", "c"})
                                       .add_indices("di", 3, input_shape.size() + 1)));
                    f.add(builder::Elementwise{"Mean", "SumInput / EltCount"});

                    if (input_shape.size() <= 2)
                    {
                        f.add(builder::Elementwise{"MeanP", "Mean"});
                    }
                    else
                    {
                        f.add(builder::Elementwise{
                            "MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
                    }

                    f.add(builder::Elementwise{"DiffV", "(Input - MeanP)"})
                        .add(builder::Elementwise{"SqDiffV", "DiffV*DiffV"})
                        .add(builder::UnaryContraction{"+"}
                                 .set(builder::ContractionOutput{"SumSqDiffV"}
                                          .add_indices({"c"})
                                          .add_dims({"C"}))
                                 .set(builder::ContractionInput{"SqDiffV"}
                                          .add_indices({"b", "c"})
                                          .add_indices("di", 3, input_shape.size() + 1)))
                        .add(builder::Elementwise{"Variance", "SumSqDiffV / EltCount"});

                    if (input_shape.size() <= 2)
                    {
                        f.add(builder::Elementwise{"VarianceP", "Variance"});
                    }
                    else
                    {
                        f.add(builder::Elementwise{
                            "VarianceP", std::string{"reshape(Variance, C"} + ones + ")"});
                    }

                    f.add(builder::Elementwise{"Normalized",
                                               "(((Input-MeanP) / sqrt(VarianceP + " +
                                                   std::to_string(op().get_eps_value()) +
                                                   ")) * GammaP) + BetaP"});

                    auto app = f.finalize();

                    set_output(0, app.get_output(0));
                    set_output(1, app.get_output(1));
                    set_output(2, app.get_output(2));
                }

                template <>
                void Impl<op::BatchNormTrainingBackprop>::operator()()
                {
                    // WARNING: I'm unconvinced that we have sufficient test coverage for BatchNorm
                    // backprop and in particular I'm concerned that Gamma/Beta and Mean/Var could be
                    // swapped without the tests catching it.
                    check_inputs(6);
                    check_outputs(3);

                    auto& input_shape = op().get_input_shape(2);
                    std::string epsilon = std::to_string(op().get_eps_value());

                    auto f = start_tile_function();
                    // Header
                    f.add(builder::Input{op_input(0), "Gamma"}.add_dims({"C"}))
                        .add(builder::Input{op_input(1), "Beta"}.add_dims({"C"}))
                        .add(builder::Input{op_input(2), "Input"}
                                 .add_dims({"N", "C"})
                                 .add_dims("X", 3, input_shape.size() + 1))
                        .add(builder::Input{op_input(3), "Mean"}.add_dims({"C"}))
                        .add(builder::Input{op_input(4), "Var"}.add_dims({"C"}))
                        .add(builder::Input{op_input(5), "DOutput"}
                                 .add_dims({"N", "C"})
                                 .add_dims("X", 3, input_shape.size() + 1));
                    f.add(builder::Output{"DInput"});
                    f.add(builder::Output{"DGamma"});
                    f.add(builder::Output{"DBeta"});

                    // Prep for body
                    builder::ContractionOutput broadcast_gamma{"BroadcastGamma"};
                    builder::ContractionOutput broadcast_dgamma{"BroadcastDGamma"};
                    builder::ContractionOutput broadcast_dbeta{"BroadcastDBeta"};
                    broadcast_gamma.add_indices({"0", "c"}).add_dims({"1", "C"});
                    broadcast_dgamma.add_indices({"0", "c"}).add_dims({"1", "C"});
                    broadcast_dbeta.add_indices({"0", "c"}).add_dims({"1", "C"});
                    for (std::size_t i = 0; i < input_shape.size() - 2; ++i)
                    {
                        broadcast_gamma.add_indices({"0"}).add_dims({"1"});
                        broadcast_dgamma.add_indices({"0"}).add_dims({"1"});
                        broadcast_dbeta.add_indices({"0"}).add_dims({"1"});
                    }
                    std::ostringstream reduction_dims;
                    reduction_dims << "(" << "N";
                    for (std::size_t i = 3; i < input_shape.size() + 1; ++i)
                    {
                        reduction_dims << " * X" << i;
                    }
                    reduction_dims << ")";

                    // Body
                    f.add(builder::UnaryContraction{"+"}
                              .set(builder::ContractionOutput{"BatchMeanNumerator"}
                                       .add_indices({"0", "c", "0", "0"})
                                       .add_dims({"1", "C", "1", "1"}))
                              .set(builder::ContractionInput{"Input"}
                                       .add_indices({"n", "c"})
                                       .add_indices("x", 3, input_shape.size() + 1)));
                    f.add(builder::Elementwise{"BatchMean",
                                               "BatchMeanNumerator / " + reduction_dims.str()});
                    f.add(builder::Elementwise{"NegBatchMean", "-BatchMean"});
                    f.add(builder::BinaryContraction{"=", "+"}
                              .set(builder::ContractionOutput{"Deviation"}
                                       .add_indices({"n", "c"})
                                       .add_indices("x", 3, input_shape.size() + 1)
                                       .add_dims({"N", "C"})
                                       .add_dims("X", 3, input_shape.size() + 1))
                              .set_lhs(builder::ContractionInput{"Input"}
                                           .add_indices({"n", "c"})
                                           .add_indices("x", 3, input_shape.size() + 1))
                              .set_rhs(builder::ContractionInput{"NegBatchMean"}
                                           .add_indices({"0", "c", "0", "0"})));
                    f.add(builder::BinaryContraction{"+", "*"}
                              .set(builder::ContractionOutput{"BatchVarNumerator"}
                                       .add_indices({"0", "c", "0", "0"})
                                       .add_dims({"1", "C", "1", "1"}))
                              .set_lhs(builder::ContractionInput{"Deviation"}
                                           .add_indices({"n", "c"})
                                           .add_indices("x", 3, input_shape.size() + 1))
                              .set_rhs(builder::ContractionInput{"Deviation"}
                                           .add_indices({"n", "c"})
                                           .add_indices("x", 3, input_shape.size() + 1)));
                    f.add(builder::Elementwise{"BatchVar",
                                               "BatchVarNumerator / " + reduction_dims.str()});
                    f.add(builder::Elementwise{"BatchStdDev",
                                               "sqrt(BatchVar + " + epsilon + ")"});
                    f.add(builder::Elementwise{"NormedInput",
                                               "(Input - BatchMean) / BatchStdDev"});
                    f.add(builder::Elementwise{"ZeroedInput", "Input - BatchMean"});
                    f.add(builder::UnaryContraction{"="}
                              .set(broadcast_gamma)
                              .set(builder::ContractionInput{"Gamma"}.add_indices({"c"})));
                    f.add(builder::Elementwise{"DNormedInput", "DOutput * BroadcastGamma"});
                    f.add(builder::UnaryContraction{"+"}
                              .set(builder::ContractionOutput{"SumDOutput"}
                                       .add_indices({"c"})
                                       .add_dims({"C"}))
                              .set(builder::ContractionInput{"DOutput"}
                                       .add_indices({"n", "c"})
                                       .add_indices("x", 3, input_shape.size() + 1)));
                    f.add(builder::BinaryContraction{"+", "*"}
                              .set(builder::ContractionOutput{"DGamma"}
                                       .add_indices({"c"})
                                       .add_dims({"C"}))
                              .set_lhs(builder::ContractionInput{"DOutput"}
                                           .add_indices({"n", "c"})
                                           .add_indices("x", 3, input_shape.size() + 1))
                              .set_rhs(builder::ContractionInput{"NormedInput"}
                                           .add_indices({"n", "c"})
                                           .add_indices("x", 3, input_shape.size() + 1)));
                    f.add(builder::Elementwise{"DBeta", "SumDOutput"});
                    f.add(builder::UnaryContraction{"="}
                              .set(broadcast_dgamma)
                              .set(builder::ContractionInput{"DGamma"}.add_indices({"c"})));
                    f.add(builder::UnaryContraction{"="}
                              .set(broadcast_dbeta)
                              .set(builder::ContractionInput{"DBeta"}.add_indices({"c"})));
                    f.add(builder::Elementwise{
                        "DInput",
                        "(BroadcastGamma / BatchStdDev) * (DOutput - "
                        "(NormedInput * BroadcastDGamma + BroadcastDBeta) / (" +
                            reduction_dims.str() + "))"});

                    // Return results
                    auto app = f.finalize();
                    set_output(0, app.get_output(0));
                    set_output(1, app.get_output(1));
                    set_output(2, app.get_output(2));
                }

                namespace
                {
                    Impl<op::BatchNormInference>::Registration register_batch_norm_inference;
                    Impl<op::BatchNormTraining>::Registration register_batch_norm_training;
                    Impl<op::BatchNormTrainingBackprop>::Registration
                        register_batch_norm_training_backprop;
                }
            }
        }
    }
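For reference, the Tile code above encodes standard batch normalization; restating the mathematics (not part of the diff), with the per-channel statistics taken over the batch and spatial axes and N the reduction_dims product (batch size times spatial extents):

    \[
      \mu_c = \frac{1}{N}\sum_{n,x} I_{n,c,x}, \qquad
      \sigma_c^2 = \frac{1}{N}\sum_{n,x}\bigl(I_{n,c,x}-\mu_c\bigr)^2, \qquad
      \hat I_{n,c,x} = \frac{I_{n,c,x}-\mu_c}{\sqrt{\sigma_c^2+\varepsilon}}
    \]
    \[
      O_{n,c,x} = \gamma_c\,\hat I_{n,c,x} + \beta_c
    \]
    \[
      \frac{\partial L}{\partial\gamma_c}
        = \sum_{n,x}\frac{\partial L}{\partial O_{n,c,x}}\,\hat I_{n,c,x},
      \qquad
      \frac{\partial L}{\partial\beta_c}
        = \sum_{n,x}\frac{\partial L}{\partial O_{n,c,x}}
    \]
    \[
      \frac{\partial L}{\partial I_{n,c,x}}
        = \frac{\gamma_c}{\sqrt{\sigma_c^2+\varepsilon}}
          \left(\frac{\partial L}{\partial O_{n,c,x}}
            - \frac{\hat I_{n,c,x}\,\frac{\partial L}{\partial\gamma_c}
                    + \frac{\partial L}{\partial\beta_c}}{N}\right)
    \]

In the backprop code, NormedInput is \(\hat I\), BroadcastDGamma and BroadcastDBeta broadcast \(\partial L/\partial\gamma\) and \(\partial L/\partial\beta\) back to the input shape, and the DInput elementwise expression is exactly the last equation.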
src/ngraph/runtime/plaidml/plaidml_ops_comparison.cpp
@@ -24,132 +24,141 @@

    #include "ngraph/op/not_equal.hpp"
    #include "ngraph/runtime/plaidml/plaidml_impl.hpp"

Removed: the same eight implementations (Equal, Greater, GreaterEq, Less, LessEq, Maximum, Minimum, NotEqual) and their Registration objects at global scope with fully qualified names; the bodies are identical to the added version below.

Added:

    namespace ngraph
    {
        namespace runtime
        {
            namespace plaidml
            {
                // Equal performs a simple elementwise equality.
                template <>
                void Impl<op::Equal>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"})
                                   .add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "A == B"})
                                   .finalize(),
                               TensorContents::LOGICAL);
                }

                // Greater performs a simple elementwise greater-than comparison.
                template <>
                void Impl<op::Greater>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "A"})
                                   .add(builder::Input{op_input(1), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "A > B"})
                                   .finalize(),
                               TensorContents::LOGICAL);
                }

                // GreaterEq performs a simple elementwise greater-than-or-equal-to comparison.
                template <>
                void Impl<op::GreaterEq>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "A"})
                                   .add(builder::Input{op_input(1), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "A >= B"})
                                   .finalize(),
                               TensorContents::LOGICAL);
                }

                // Less performs a simple elementwise less-than comparison.
                template <>
                void Impl<op::Less>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "A"})
                                   .add(builder::Input{op_input(1), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "A < B"})
                                   .finalize(),
                               TensorContents::LOGICAL);
                }

                // LessEq performs a simple elementwise less-than-or-equal-to comparison.
                template <>
                void Impl<op::LessEq>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "A"})
                                   .add(builder::Input{op_input(1), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "A <= B"})
                                   .finalize(),
                               TensorContents::LOGICAL);
                }

                // Maximum performs a simple elementwise maximum.
                template <>
                void Impl<op::Maximum>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "A"})
                                   .add(builder::Input{op_input(1), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "max(A, B)"})
                                   .finalize());
                }

                // Minimum performs a simple elementwise minimum.
                template <>
                void Impl<op::Minimum>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0), "A"})
                                   .add(builder::Input{op_input(1), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "min(A, B)"})
                                   .finalize());
                }

                // NotEqual performs a simple elementwise not-equality.
                template <>
                void Impl<op::NotEqual>::operator()()
                {
                    check_inputs(2);
                    check_outputs(1);
                    set_output(start_tile_function()
                                   .add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"})
                                   .add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"})
                                   .add(builder::Output{"C"})
                                   .add(builder::Elementwise{"C", "A != B"})
                                   .finalize(),
                               TensorContents::LOGICAL);
                }

                namespace
                {
                    Impl<op::Equal>::Registration register_equal;
                    Impl<op::Greater>::Registration register_greater;
                    Impl<op::GreaterEq>::Registration register_greater_eq;
                    Impl<op::Less>::Registration register_less;
                    Impl<op::LessEq>::Registration register_less_eq;
                    Impl<op::Maximum>::Registration register_maximum;
                    Impl<op::Minimum>::Registration register_minimum;
                    Impl<op::NotEqual>::Registration register_not_equal;
                }
            }
        }
    }
src/ngraph/runtime/plaidml/plaidml_ops_concat.cpp
@@ -17,87 +17,101 @@

    #include "ngraph/op/concat.hpp"
    #include "ngraph/runtime/plaidml/plaidml_impl.hpp"

Removed: the same implementation and Registration at global scope with fully qualified names; the body is identical to the added version below.

Added:

    namespace ngraph
    {
        namespace runtime
        {
            namespace plaidml
            {
                // Concat concatenates its inputs along the op's concatenation axis.
                template <>
                void Impl<op::Concat>::operator()()
                {
                    check_outputs(1);

                    auto f = start_tile_function();
                    f.add(builder::Output{"O"});
                    std::size_t dim_count = op().get_shape().size();
                    std::ostringstream offset;
                    std::ostringstream oexpr;
                    std::ostringstream concat_dsize;
                    bool saw_non_zero_tensor = false;
                    for (std::size_t iidx = 0; iidx < op().get_inputs().size(); ++iidx)
                    {
                        if (!shape_size(op().get_input_shape(iidx)))
                        {
                            continue;
                        }
                        if (saw_non_zero_tensor)
                        {
                            concat_dsize << "+";
                        }
                        saw_non_zero_tensor = true;
                        concat_dsize << "I" << iidx << "_D" << op().get_concatenation_axis();
                    }

                    saw_non_zero_tensor = false;
                    for (std::size_t iidx = 0; iidx < op().get_inputs().size(); ++iidx)
                    {
                        if (!shape_size(op().get_input_shape(iidx)))
                        {
                            continue;
                        }
                        std::string sidx{std::to_string(iidx)};
                        f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims(
                            "I" + sidx + "_D", 0, dim_count));
                        f.add(builder::UnaryContraction{"="}
                                  .set(builder::ContractionOutput{"E" + sidx}
                                           .add_dims([&](std::back_insert_iterator<
                                                         std::list<std::string>> out) {
                                               for (std::size_t idx = 0; idx < dim_count; ++idx)
                                               {
                                                   std::ostringstream s;
                                                   if (idx == op().get_concatenation_axis())
                                                   {
                                                       out = concat_dsize.str();
                                                   }
                                                   else
                                                   {
                                                       s << "I" << iidx << "_D" << idx;
                                                       out = s.str();
                                                   }
                                               }
                                           })
                                           .add_indices([&](std::back_insert_iterator<
                                                            std::list<std::string>> out) {
                                               for (std::size_t idx = 0; idx < dim_count; ++idx)
                                               {
                                                   std::ostringstream s;
                                                   s << "d" << idx;
                                                   if (saw_non_zero_tensor &&
                                                       idx == op().get_concatenation_axis())
                                                   {
                                                       s << " + " << offset.str();
                                                   }
                                                   out = s.str();
                                               }
                                           }))
                                  .set(builder::ContractionInput{"I" + sidx}.add_indices(
                                      "d", 0, dim_count)));
                        if (saw_non_zero_tensor)
                        {
                            oexpr << " + ";
                            offset << " + ";
                        }
                        oexpr << "E" << sidx;
                        offset << "I" << iidx << "_D" << op().get_concatenation_axis();
                        saw_non_zero_tensor = true;
                    }
                    f.add(builder::Elementwise{"O", oexpr.str()});

                    set_output(f.finalize());
                }

                namespace
                {
                    Impl<op::Concat>::Registration register_concat;
                }
            }
        }
    }
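The two loops above are the interesting part of Concat: the first computes the output extent along the concatenation axis as the sum of the non-empty inputs' extents (concat_dsize), and the second writes each input at a running offset equal to the summed extents of the inputs before it, skipping zero-sized tensors. A standalone sketch of the same bookkeeping on plain vectors (simplified, assumed code, not the ngraph API; the 1-D case stands in for indexing along the concatenation axis):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Concatenate 1-D "tensors"; the axis handling collapses to a running
    // offset, which is what the offset ostringstream accumulates above.
    std::vector<int> concat(const std::vector<std::vector<int>>& inputs)
    {
        std::size_t total = 0;
        for (const auto& in : inputs)
        {
            total += in.size();  // concat_dsize: "I0_Daxis+I1_Daxis+..."
        }

        std::vector<int> out(total);
        std::size_t offset = 0;  // sum of earlier non-empty input extents
        for (const auto& in : inputs)
        {
            if (in.empty())
            {
                continue;  // zero-sized tensors are skipped, as in the op
            }
            for (std::size_t d = 0; d < in.size(); ++d)
            {
                out[d + offset] = in[d];  // index "d + offset" along the axis
            }
            offset += in.size();
        }
        return out;
    }

    int main()
    {
        for (int v : concat({{1, 2}, {}, {3, 4, 5}}))
        {
            std::cout << v << ' ';  // prints: 1 2 3 4 5
        }
    }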
src/ngraph/runtime/plaidml/plaidml_ops_convert.cpp
@@ -18,21 +18,31 @@

    #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
    #include "ngraph/runtime/plaidml/plaidml_translate.hpp"

Removed: the same implementation and Registration at global scope with fully qualified names; the body is identical to the added version below.

Added:

    namespace ngraph
    {
        namespace runtime
        {
            namespace plaidml
            {
                // Convert views a tensor as a new type.
                template <>
                void Impl<op::Convert>::operator()()
                {
                    check_inputs(1);
                    check_outputs(1);
                    set_output(
                        start_tile_function()
                            .add(builder::Input{op_input(), "I"})
                            .add(builder::Output{"O"})
                            .add(builder::Elementwise{
                                "O",
                                tile_converter("I",
                                               to_plaidml(op().get_convert_element_type()))})
                            .finalize());
                }

                namespace
                {
                    Impl<op::Convert>::Registration register_convert;
                }
            }
        }
    }
src/ngraph/runtime/plaidml/plaidml_ops_convolution.cpp
View file @
61df6725
...
...
@@ -50,234 +50,232 @@ namespace ngraph
std
::
size_t
output_channel_axis_result
,
bool
rotate_filter
);
};
}
}
}
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
Convolution
>
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ConvolutionImpl
<
ngraph
::
op
::
Convolution
>
;
};
template
<>
struct
ParentImpl
<
op
::
Convolution
>
{
using
Type
=
ConvolutionImpl
<
op
::
Convolution
>
;
};
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
ConvolutionBackpropFilters
>
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ConvolutionImpl
<
ngraph
::
op
::
ConvolutionBackpropFilters
>
;
};
template
<>
struct
ParentImpl
<
op
::
ConvolutionBackpropFilters
>
{
using
Type
=
ConvolutionImpl
<
op
::
ConvolutionBackpropFilters
>
;
};
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
ConvolutionBackpropData
>
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ConvolutionImpl
<
ngraph
::
op
::
ConvolutionBackpropData
>
;
};
template
<>
struct
ParentImpl
<
op
::
ConvolutionBackpropData
>
{
using
Type
=
ConvolutionImpl
<
op
::
ConvolutionBackpropData
>
;
};
// Convolution implements a standard ML convolultion, with optional striding, padding, and dilation.
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Convolution
>::
operator
()()
{
this
->
check_inputs
(
2
);
this
->
check_outputs
(
1
);
// Convolution implements a standard ML convolultion, with optional striding, padding, and dilation.
template
<>
void
Impl
<
op
::
Convolution
>::
operator
()()
{
this
->
check_inputs
(
2
);
this
->
check_outputs
(
1
);
LogConvolution
(
op_input
(
0
),
op_input
(
1
),
op
().
get_inputs
()[
0
].
get_shape
().
size
()
-
2
,
op
().
get_window_movement_strides
(),
op
().
get_window_dilation_strides
(),
op
().
get_padding_below
(),
op
().
get_padding_above
(),
op
().
get_data_dilation_strides
(),
0
,
1
,
1
,
0
,
0
,
1
,
false
);
LogConvolution
(
op_input
(
0
),
op_input
(
1
),
op
().
get_inputs
()[
0
].
get_shape
().
size
()
-
2
,
op
().
get_window_movement_strides
(),
op
().
get_window_dilation_strides
(),
op
().
get_padding_below
(),
op
().
get_padding_above
(),
op
().
get_data_dilation_strides
(),
0
,
1
,
1
,
0
,
0
,
1
,
false
);
const
auto
&
image
=
op_input
(
0
);
const
auto
&
filter
=
op_input
(
1
);
auto
image_dims
=
op
().
get_inputs
()[
0
].
get_shape
().
size
()
-
2
;
const
auto
&
padding_above
=
op
().
get_padding_above
();
const
auto
&
padding_below
=
op
().
get_padding_below
();
const
auto
&
strides
=
op
().
get_window_movement_strides
();
const
auto
&
filter_dilation
=
op
().
get_window_dilation_strides
();
const
auto
&
data_dilation
=
op
().
get_data_dilation_strides
();
const
auto
&
image
=
op_input
(
0
);
const
auto
&
filter
=
op_input
(
1
);
auto
image_dims
=
op
().
get_inputs
()[
0
].
get_shape
().
size
()
-
2
;
const
auto
&
padding_above
=
op
().
get_padding_above
();
const
auto
&
padding_below
=
op
().
get_padding_below
();
const
auto
&
strides
=
op
().
get_window_movement_strides
();
const
auto
&
filter_dilation
=
op
().
get_window_dilation_strides
();
const
auto
&
data_dilation
=
op
().
get_data_dilation_strides
();
ConvPoolFormatter
cpf
(
image_dims
,
padding_below
,
padding_above
,
strides
,
filter_dilation
,
data_dilation
,
ConvPoolFormatter
::
OpType
::
Conv
,
ConvPoolFormatter
::
DerivType
::
None
);
ConvPoolFormatter
cpf
(
image_dims
,
padding_below
,
padding_above
,
strides
,
filter_dilation
,
data_dilation
,
ConvPoolFormatter
::
OpType
::
Conv
,
ConvPoolFormatter
::
DerivType
::
None
);
this
->
set_output
(
start_tile_function
()
.
add
(
cpf
.
I_in_header
(
image
))
.
add
(
cpf
.
F_in_header
(
filter
))
.
add
(
cpf
.
O_out_header
())
.
add
(
builder
::
BinaryContraction
{
"+"
,
"*"
}
.
set
(
cpf
.
O_out_body
())
.
set_lhs
(
cpf
.
I_in_body
())
.
set_rhs
(
cpf
.
F_in_body
()))
.
finalize
());
}
this
->
set_output
(
start_tile_function
()
.
add
(
cpf
.
I_in_header
(
image
))
.
add
(
cpf
.
F_in_header
(
filter
))
.
add
(
cpf
.
O_out_header
())
.
add
(
builder
::
BinaryContraction
{
"+"
,
"*"
}
.
set
(
cpf
.
O_out_body
())
.
set_lhs
(
cpf
.
I_in_body
())
.
set_rhs
(
cpf
.
F_in_body
()))
.
finalize
());
}
-// ConvolutionBackpropFilters implements the derivative of a convolution with respect to its filter
-// input.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropFilters>::operator()()
+            // ConvolutionBackpropFilters implements the derivative of a convolution with respect
+            // to its filter input.
+            template <>
+            void Impl<op::ConvolutionBackpropFilters>::operator()()
 {
     this->check_inputs(2);
     this->check_outputs(1);

     LogConvolution(op_input(0),
                    op_input(1),
                    op().get_inputs()[0].get_shape().size() - 2,
                    op().get_window_movement_strides_backward(),
                    op().get_window_dilation_strides_backward(),
                    op().get_padding_below_backward(),
                    op().get_padding_above_backward(),
                    op().get_data_dilation_strides_backward(),
                    1, 0, 0, 1, 1, 0,
                    false);

     const auto& image = op_input(0);
     const auto& output = op_input(1);
     auto image_dims = op().get_inputs()[0].get_shape().size() - 2;
     const auto& padding_above = op().get_padding_above_forward();
     const auto& padding_below = op().get_padding_below_forward();
     const auto& strides = op().get_window_movement_strides_forward();
     const auto& filter_dilation = op().get_window_dilation_strides_forward();
     const auto& data_dilation = op().get_data_dilation_strides_forward();
     const auto& filters_shape = op().get_filters_shape();

     ConvPoolFormatter cpf(image_dims,
                           padding_below,
                           padding_above,
                           strides,
                           filter_dilation,
                           data_dilation,
                           ConvPoolFormatter::OpType::Conv,
                           ConvPoolFormatter::DerivType::Filter,
                           filters_shape);

     this->set_output(start_tile_function()
                          .add(cpf.I_in_header(image))
                          .add(cpf.O_in_header(output))
                          .add(cpf.F_out_header())
                          .add(builder::BinaryContraction{"+", "*"}
                                   .set(cpf.F_out_body())
                                   .set_lhs(cpf.O_in_body())
                                   .set_rhs(cpf.I_in_body()))
                          .finalize());
 }
-// ConvolutionBackpropData implements the derivative of a convolution with respect to its data
-// input.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropData>::operator()()
+            // ConvolutionBackpropData implements the derivative of a convolution with respect to
+            // its data input.
+            template <>
+            void Impl<op::ConvolutionBackpropData>::operator()()
 {
     this->check_inputs(2);
     this->check_outputs(1);

     LogConvolution(op_input(0),
                    op_input(1),
                    op().get_inputs()[1].get_shape().size() - 2,
                    op().get_window_movement_strides_backward(),
                    op().get_window_dilation_strides_backward(),
                    op().get_padding_below_backward(),
                    op().get_padding_above_backward(),
                    op().get_data_dilation_strides_backward(),
                    0, 1, 0, 1, 0, 1,
                    true);

     auto image_dims = op().get_inputs()[0].get_shape().size() - 2;
     const auto& filter = op_input(0);
     const auto& output = op_input(1);
     const auto& padding_above = op().get_padding_above_forward();
     const auto& padding_below = op().get_padding_below_forward();
     const auto& strides = op().get_window_movement_strides_forward();
     const auto& filter_dilation = op().get_window_dilation_strides_forward();
     const auto& data_dilation = op().get_data_dilation_strides_forward();
     const auto& data_batch_shape = op().get_data_batch_shape();

     ConvPoolFormatter cpf(image_dims,
                           padding_below,
                           padding_above,
                           strides,
                           filter_dilation,
                           data_dilation,
                           ConvPoolFormatter::OpType::Conv,
                           ConvPoolFormatter::DerivType::Data,
                           data_batch_shape);

     this->set_output(start_tile_function()
                          .add(cpf.F_in_header(filter))
                          .add(cpf.O_in_header(output))
                          .add(cpf.I_out_header())
                          .add(builder::BinaryContraction{"+", "*"}
                                   .set(cpf.I_out_body())
                                   .set_lhs(cpf.O_in_body())
                                   .set_rhs(cpf.F_in_body()))
                          .finalize());
 }
-template <typename O>
-inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
+            template <typename O>
+            inline void ConvolutionImpl<O>::LogConvolution(
     vertexai::plaidml::variable image,
     vertexai::plaidml::variable filter,
     std::size_t image_dims,
     const Strides& window_movement_strides,
     const Strides& window_dilation_strides,
     const CoordinateDiff& padding_below,
     const CoordinateDiff& padding_above,
     const Strides& data_dilation_strides,
     std::size_t batch_axis_data,
     std::size_t input_channel_axis_data,
     std::size_t input_channel_axis_filters,
     std::size_t output_channel_axis_filters,
     std::size_t batch_axis_result,
     std::size_t output_channel_axis_result,
     bool rotate_filter)
 {
     this->check_inputs(2);
     this->check_outputs(1);

     NGRAPH_DEBUG << "image_dims: " << image_dims;
     NGRAPH_DEBUG << "first_dims: " << this->op().get_inputs()[0].get_shape();
     NGRAPH_DEBUG << "second_dims: " << this->op().get_inputs()[1].get_shape();
     NGRAPH_DEBUG << "output_dims: " << this->op().get_outputs()[0].get_shape();
     NGRAPH_DEBUG << "padding_below: " << padding_below;
     NGRAPH_DEBUG << "padding_above: " << padding_above;
     NGRAPH_DEBUG << "window_movement_strides: " << window_movement_strides;
     NGRAPH_DEBUG << "window_dilation_strides: " << window_dilation_strides;
     NGRAPH_DEBUG << "data_dilation_strides: " << data_dilation_strides;
     NGRAPH_DEBUG << "batch_axis_data: " << batch_axis_data;
     NGRAPH_DEBUG << "input_channel_axis_data: " << input_channel_axis_data;
     NGRAPH_DEBUG << "input_channel_axis_filters: " << input_channel_axis_filters;
     NGRAPH_DEBUG << "output_channel_axis_filters: " << output_channel_axis_filters;
     NGRAPH_DEBUG << "batch_axis_result: " << batch_axis_result;
     NGRAPH_DEBUG << "output_channel_axis_result: " << output_channel_axis_result;
     NGRAPH_DEBUG << "rotate_filter: " << rotate_filter;
 }
-namespace
-{
-    ngraph::runtime::plaidml::Impl<ngraph::op::Convolution>::Registration register_convolution;
-    ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropFilters>::Registration
-        register_convolution_backprop_filters;
-    ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropData>::Registration
-        register_convolution_backprop_data;
-}
+            namespace
+            {
+                Impl<op::Convolution>::Registration register_convolution;
+                Impl<op::ConvolutionBackpropFilters>::Registration
+                    register_convolution_backprop_filters;
+                Impl<op::ConvolutionBackpropData>::Registration
+                    register_convolution_backprop_data;
+            }
+        }
+    }
+}
src/ngraph/runtime/plaidml/plaidml_ops_dot.cpp  View file @ 61df6725
@@ -20,50 +20,60 @@
 #include "ngraph/op/dot.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
-// Dot is a generalized dot product operation -- scalar-tensor,
-// matrix-vector, and matrix multiplication.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
+namespace ngraph
+{
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // Dot is a generalized dot product operation -- scalar-tensor,
+            // matrix-vector, and matrix multiplication.
+            template <>
+            void Impl<op::Dot>::operator()()
 {
     check_inputs(2);
     check_outputs(1);

     auto l_dim_limit = op().get_inputs()[0].get_shape().size();
     auto r_dim_limit = op().get_inputs()[1].get_shape().size();
     auto reduce_limit = op().get_reduction_axes_count();
     auto l_dim_mac = l_dim_limit - reduce_limit;
     auto r_dim_mic = reduce_limit;

     NGRAPH_DEBUG << "l_dim_limit=" << l_dim_limit;
     NGRAPH_DEBUG << "r_dim_limit=" << r_dim_limit;
     NGRAPH_DEBUG << "reduce_limit=" << reduce_limit;
     NGRAPH_DEBUG << "l_dim_mac=" << l_dim_mac;
     NGRAPH_DEBUG << "r_dim_mic=" << r_dim_mic;

     set_output(
         start_tile_function()
             .add(builder::Input{op_input(0), "L"}
                      .add_dims("DL", 1, l_dim_mac + 1)
                      .add_dims("DC", 1, reduce_limit + 1))
             .add(builder::Input{op_input(1), "R"}
                      .add_dims("DC", 1, reduce_limit + 1)
                      .add_dims("DR", r_dim_mic + 1, r_dim_limit + 1))
             .add(builder::Output{"O"})
             .add(builder::BinaryContraction{"+", "*"}
                      .set(builder::ContractionOutput{"O"}
                               .add_indices("dl", 1, l_dim_mac + 1)
                               .add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)
                               .add_dims("DL", 1, l_dim_mac + 1)
                               .add_dims("DR", r_dim_mic + 1, r_dim_limit + 1))
                      .set_lhs(builder::ContractionInput{"L"}
                                   .add_indices("dl", 1, l_dim_mac + 1)
                                   .add_indices("dc", 1, reduce_limit + 1))
                      .set_rhs(builder::ContractionInput{"R"}
                                   .add_indices("dc", 1, reduce_limit + 1)
                                   .add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)))
             .finalize());
 }
-namespace
-{
-    ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::Registration register_dot;
-}
+            namespace
+            {
+                Impl<op::Dot>::Registration register_dot;
+            }
+        }
+    }
+}
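// Aside: the contraction above encodes the usual shape rule for a generalized
// dot -- with r reduction axes, the output keeps the first (rank(L) - r) axes
// of L and the last (rank(R) - r) axes of R.  A small standalone check
// (illustrative; dot_output_shape is a made-up helper):

#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::size_t> dot_output_shape(const std::vector<std::size_t>& lshape,
                                          const std::vector<std::size_t>& rshape,
                                          std::size_t r)
{
    // Drop the last r axes of L and the first r axes of R, then concatenate.
    std::vector<std::size_t> out(lshape.begin(), lshape.end() - r);
    out.insert(out.end(), rshape.begin() + r, rshape.end());
    return out;
}

int main()
{
    // Matrix multiplication: {2, 3} . {3, 4} with r = 1 -> {2, 4}.
    for (auto d : dot_output_shape({2, 3}, {3, 4}, 1))
    {
        std::cout << d << " ";
    }
    std::cout << "\n";
}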
src/ngraph/runtime/plaidml/plaidml_ops_function.cpp  View file @ 61df6725
@@ -19,29 +19,39 @@
 #include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
-// FunctionCall invokes a sub-function.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
+namespace ngraph
+{
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // FunctionCall invokes a sub-function.
+            template <>
+            void Impl<op::FunctionCall>::operator()()
 {
     Build b;
     build()->compiler->build(op().get_functions()[0], &b);
     vertexai::plaidml::function f{b.composer};
     vertexai::plaidml::function::parameters_t inputs;
     for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
     {
         auto* oitv = op().get_inputs()[idx].get_output().get_tensor_ptr().get();
         auto* iitv = b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
         inputs.emplace_back(b.input_names.at(iitv), build()->bindings.at(oitv).var);
     }
     vertexai::plaidml::application app{f.apply(inputs)};
     for (std::size_t idx = 0; idx < op().get_output_size(); ++idx)
     {
         auto* iotv = b.func->get_results()[idx]->get_output_tensor_ptr().get();
         set_output(idx, app.get_output(b.output_names[iotv]));
     }
 }
-namespace
-{
-    ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::Registration register_function_call;
-}
+            namespace
+            {
+                Impl<op::FunctionCall>::Registration register_function_call;
+            }
+        }
+    }
+}
src/ngraph/runtime/plaidml/plaidml_ops_general.cpp  View file @ 61df6725
@@ -28,428 +28,458 @@
 namespace vp = vertexai::plaidml;
-// Broadcast broadcasts a tensor to a wider shape.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
+namespace ngraph
+{
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // Broadcast broadcasts a tensor to a wider shape.
+            template <>
+            void Impl<op::Broadcast>::operator()()
 {
     check_inputs(1);
     check_outputs(1);

     auto in_dim_limit = op().get_inputs()[0].get_shape().size();
     auto out_dim_limit = op().get_broadcast_shape().size();

     NGRAPH_DEBUG << "Broadcast in_dim_limit: " << in_dim_limit
                  << " out_dim_limit:" << out_dim_limit;
     NGRAPH_DEBUG << "Broadcast axes: " << op().get_broadcast_axes();
     NGRAPH_DEBUG << "Broadcast input shape: " << op().get_input_shape(0);
     NGRAPH_DEBUG << "Broadcast output shape: " << op().get_broadcast_shape();

     auto input_didx = in_dim_limit;
     std::vector<std::size_t> out_didxs;
     for (std::size_t idx = 0; idx < out_dim_limit; ++idx)
     {
         if (!op().get_broadcast_axes().count(idx))
         {
             out_didxs.push_back(out_dim_limit - idx - 1);
         }
     }

     set_output(
         start_tile_function()
             .add(builder::Input{op_input(0), "I"}.add_rdims("D", in_dim_limit, 0))
             .add(builder::Output{"O"})
             .add(builder::UnaryContraction{"="}
                      .set(builder::ContractionOutput{"O"}
                               .add_rindices("o", out_dim_limit, 0)
                               .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                   for (std::size_t idx = 0; idx < out_dim_limit; ++idx)
                                   {
                                       if (op().get_broadcast_axes().count(idx))
                                       {
                                           out = std::to_string(op().get_broadcast_shape()[idx]);
                                       }
                                       else
                                       {
                                           out = "D" + std::to_string(--input_didx);
                                       }
                                   }
                               }))
                      .set(builder::ContractionInput{"I"}.add_indices(
                          [&](std::back_insert_iterator<std::list<std::string>> out) {
                              for (std::size_t idx = 0; idx < in_dim_limit; ++idx)
                              {
                                  out = "o" + std::to_string(out_didxs[idx]);
                              }
                          })))
             .finalize());
 }
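// Aside: the index bookkeeping above reduces to a simple rule -- every output
// coordinate maps to the input coordinate obtained by dropping the broadcast
// axes.  A reference sketch for one concrete case (illustrative only; the
// helper below is not part of this file):

#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> broadcast_1d_to_2d(const std::vector<float>& in,
                                      std::size_t rows,
                                      std::size_t cols,
                                      bool broadcast_rows) // true: axis 0 is the broadcast axis
{
    std::vector<float> out(rows * cols);
    for (std::size_t r = 0; r < rows; ++r)
    {
        for (std::size_t c = 0; c < cols; ++c)
        {
            // Drop the broadcast axis from (r, c) to index into the input.
            out[r * cols + c] = broadcast_rows ? in[c] : in[r];
        }
    }
    return out;
}

int main()
{
    // {3} broadcast to {2, 3} along axis 0: each row is a copy of the input.
    for (auto v : broadcast_1d_to_2d({1, 2, 3}, 2, 3, true))
        std::cout << v << " "; // 1 2 3 1 2 3
    std::cout << "\n";
}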
-// Constant fills in a tensor constant.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
+            // Constant fills in a tensor constant.
+            template <>
+            void Impl<op::Constant>::operator()()
 {
     check_inputs(0);
     check_outputs(1);

     bool output_to_result = false;
     for (const std::shared_ptr<Node>& node : op().get_users())
     {
         if (dynamic_cast<op::Result*>(node.get()))
         {
             output_to_result = true;
             break;
         }
     }

     if (!op().get_shape().size() && !output_to_result)
     {
         switch (to_plaidml(op().get_element_type()))
         {
         case PLAIDML_DATA_BOOLEAN:
             set_output(static_cast<std::int64_t>(*static_cast<const char*>(op().get_data_ptr())));
             return;
         case PLAIDML_DATA_INT8:
             set_output(
                 static_cast<std::int64_t>(*static_cast<const std::int8_t*>(op().get_data_ptr())));
             return;
         case PLAIDML_DATA_INT16:
             set_output(
                 static_cast<std::int64_t>(*static_cast<const std::int16_t*>(op().get_data_ptr())));
             return;
         case PLAIDML_DATA_INT32:
             set_output(
                 static_cast<std::int64_t>(*static_cast<const std::int32_t*>(op().get_data_ptr())));
             return;
         case PLAIDML_DATA_INT64:
             set_output(*static_cast<const std::int64_t*>(op().get_data_ptr()));
             return;
         case PLAIDML_DATA_UINT8:
             set_output(
                 static_cast<std::int64_t>(*static_cast<const std::uint8_t*>(op().get_data_ptr())));
             return;
         case PLAIDML_DATA_UINT16:
             set_output(static_cast<std::int64_t>(
                 *static_cast<const std::uint16_t*>(op().get_data_ptr())));
             return;
         case PLAIDML_DATA_UINT32:
             set_output(static_cast<std::int64_t>(
                 *static_cast<const std::uint32_t*>(op().get_data_ptr())));
             return;
         case PLAIDML_DATA_UINT64:
             set_output(static_cast<std::int64_t>(
                 *static_cast<const std::uint64_t*>(op().get_data_ptr())));
             return;
         case PLAIDML_DATA_FLOAT16:
             set_output(static_cast<double>(
                 static_cast<float>(*static_cast<const half*>(op().get_data_ptr()))));
             return;
         case PLAIDML_DATA_FLOAT32:
             set_output(static_cast<double>(*static_cast<const float*>(op().get_data_ptr())));
             return;
         case PLAIDML_DATA_FLOAT64:
             set_output(static_cast<double>(*static_cast<const double*>(op().get_data_ptr())));
             return;
         default: break;
         }
     }

     auto tensor = build()->config->dev->allocate(
         to_plaidml(build()->config->ctx, op().get_element_type(), op().get_shape()));

     {
         vp::mapping<char> mp = tensor.map(vp::map_for_write);
         const char* src = static_cast<const char*>(op().get_data_ptr());
         char* dest = mp.raw();
         std::copy(src, src + tensor.get_shape().buffer_size(), dest);
     }

     set_output(tensor);
 }
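// Aside: the scalar fast path above boils down to reinterpreting the
// constant's raw bytes at its element type and widening the value to
// int64/double so it can be emitted as a literal.  A sketch for one type
// (illustrative, not the code above verbatim -- the real path uses
// static_cast on a typed pointer; memcpy here just sidesteps alignment):

#include <cstdint>
#include <cstring>
#include <iostream>

std::int64_t read_scalar_i32(const void* data)
{
    std::int32_t v;
    std::memcpy(&v, data, sizeof v); // copy the raw bytes at the element type
    return static_cast<std::int64_t>(v);
}

int main()
{
    std::int32_t raw = -7;
    std::cout << read_scalar_i32(&raw) << "\n"; // -7
}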
-// GetOutputElement pipes one of its N inputs to its output.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::GetOutputElement>::operator()()
+            // GetOutputElement pipes one of its N inputs to its output.
+            template <>
+            void Impl<op::GetOutputElement>::operator()()
 {
     check_inputs_ge(op().get_n() + 1);
     check_outputs(1);

     set_output(op_input(op().get_n()));
 }
-// Pad adds interior and exterior padding to a tensor.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
+            // Pad adds interior and exterior padding to a tensor.
+            template <>
+            void Impl<op::Pad>::operator()()
 {
     check_inputs(2);
     check_outputs(1);

     auto tensor = op_input(0);
     auto value = op_input(1);

     // For padding, we construct two intermediate tensors: the first is the input tensor expanded
     // by the requisite padding (with zeros in all padded locations), and the second is a boolean
     // tensor expanded the same way, but with true at the source locations and false at the padded
     // locations.  We then combine these elementwise using a trinary condition, with the pad value
     // being used everywhere the boolean intermediate is false.

     // It's a little wasteful, but it expresses the logic correctly, and doesn't take long to run;
     // the runtime is also free to optimize it through combining the intermediate contractions.

     NGRAPH_DEBUG << "Pad below: " << op().get_padding_below();
     NGRAPH_DEBUG << "Pad above: " << op().get_padding_above();
     NGRAPH_DEBUG << "Pad interior: " << op().get_padding_interior();
     NGRAPH_DEBUG << "Pad input dims: " << op().get_input_shape(0);
     NGRAPH_DEBUG << "Pad output dims: " << op().get_shape();

     auto dim_limit = op().get_shape().size();

     bool any_zero_dims = false;
     for (auto sz : op().get_input_shape(0))
     {
         if (!sz)
         {
             any_zero_dims = true;
             break;
         }
     }

     auto out_dsize = [&](std::size_t idx) {
         std::ostringstream s;
         std::size_t total_pad =
             op().get_padding_below().at(idx) + op().get_padding_above().at(idx);
         std::size_t in_dsize = op().get_input_shape(0).at(idx);
         if (in_dsize)
         {
             total_pad += op().get_padding_interior().at(idx) * (in_dsize - 1);
         }
         if (!any_zero_dims)
         {
             s << "DI" << idx + 1;
             if (total_pad)
             {
                 s << " + " << total_pad;
             }
         }
         else
         {
             s << total_pad + in_dsize;
         }
         return s.str();
     };

     auto out_didx = [&](std::size_t idx) {
         std::ostringstream s;
         auto below = op().get_padding_below().at(idx);
         if (below)
         {
             s << below << " + ";
         }
         auto interior = op().get_padding_interior().at(idx) + 1;
         if (interior != 1)
         {
             s << "(d" << idx + 1 << " * " << interior << ")";
         }
         else
         {
             s << "d" << idx + 1;
         }
         return s.str();
     };

     auto flag_constraints = [&](std::size_t idx) {
         std::ostringstream s;
         s << "d" << idx + 1 << " < DI" << idx + 1;
         return s.str();
     };

     auto f = start_tile_function();
     f.add(builder::Input{op_input(1), "V"}).add(builder::Output{"O"});

     if (!any_zero_dims)
     {
         f.add(builder::Input{op_input(0), "I"}.add_dims("DI", 1, dim_limit + 1))
             .add(builder::UnaryContraction{"="}
                      .set(builder::ContractionOutput{"P"}
                               .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
                                   for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                   {
                                       out = out_didx(idx);
                                   }
                               })
                               .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                   for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                   {
                                       out = out_dsize(idx);
                                   }
                               }))
                      .set(builder::ContractionInput{"I"}.add_indices("d", 1, dim_limit + 1)))
             .add(builder::Elementwise{"T", "1"})
             .add(builder::UnaryContraction{"="}
                      .set(builder::ContractionOutput{"F"}
                               .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
                                   for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                   {
                                       out = out_didx(idx);
                                   }
                               })
                               .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                   for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                   {
                                       out = out_dsize(idx);
                                   }
                               }))
                      .set(builder::ContractionInput{"T"})
                      .add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
                          for (std::size_t idx = 0; idx < dim_limit; ++idx)
                          {
                              out = flag_constraints(idx);
                          }
                      }))
             .add(builder::Elementwise{"O", "F ? P : V"});
     }
     else
     {
         f.add(builder::UnaryContraction{"="}
                   .set(builder::ContractionOutput{"O"}
                            .add_indices("d", 0, dim_limit)
                            .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                {
                                    out = out_dsize(idx);
                                }
                            }))
                   .set(builder::ContractionInput{"V"}));
     }

     set_output(f.finalize());
 }
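// Aside: reference semantics for the padding logic above, in plain C++ for a
// 1-D tensor -- positions that map back into the input get input values,
// everything else gets the pad value.  (Illustrative only; the Tile code
// builds the same result with two contractions and one elementwise select.)

#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> pad_1d(const std::vector<float>& in,
                          float value,
                          std::size_t below,
                          std::size_t above,
                          std::size_t interior)
{
    std::size_t out_size =
        in.size() + below + above + (in.empty() ? 0 : interior * (in.size() - 1));
    std::vector<float> out(out_size, value);
    for (std::size_t d = 0; d < in.size(); ++d)
    {
        out[below + d * (interior + 1)] = in[d]; // same index map as out_didx()
    }
    return out;
}

int main()
{
    // {a, b} with below=1, above=1, interior=1 -> {v, a, v, b, v}
    for (auto x : pad_1d({1.0f, 2.0f}, 9.0f, 1, 1, 1))
        std::cout << x << " "; // 9 1 9 2 9
    std::cout << "\n";
}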
-// Reshape reshapes an input tensor.
-//
-// The reshape operation doesn't just describe a new way of looking at an input tensor; it can
-// optionally rearrange the elements of the input tensor.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
+            // Reshape reshapes an input tensor.
+            //
+            // The reshape operation doesn't just describe a new way of looking at an input
+            // tensor; it can optionally rearrange the elements of the input tensor.
+            template <>
+            void Impl<op::Reshape>::operator()()
 {
     check_inputs(1);
     check_outputs(1);

     auto src = op_input(0);
     auto dim_limit = op().get_inputs()[0].get_shape().size();
     if (!dim_limit)
     {
         // This reshape is being used to create a tensor from a scalar.  PlaidML's reshape()
         // operation requires a tensor input (as of this writing), so instead of a reshape(),
         // we'll just use a contraction to build the tensor.
         auto& out_shape = op().get_shape();
         set_output(
             start_tile_function()
                 .add(builder::Input{src, "I"})
                 .add(builder::Output{"O"})
                 .add(builder::UnaryContraction{"="}
                          .set(builder::ContractionOutput{"O"}
                                   .add_indices("d", 0, out_shape.size())
                                   .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                       std::transform(out_shape.begin(),
                                                      out_shape.end(),
                                                      out,
                                                      [](std::size_t sz) {
                                                          return std::to_string(sz);
                                                      });
                                   }))
                          .set(builder::ContractionInput{"I"}))
                 .finalize());
         return;
     }

     std::size_t dim_idx = 0;
     auto input_order = op().get_input_order();
     for (std::size_t src_idx : op().get_input_order())
     {
         if (src_idx != dim_idx++)
         {
             // This reshape operation doesn't just describe a new way of looking at an input
             // tensor; it's also rearranging the elements of the input tensor.  This is pretty
             // easy to handle with a contraction.
             src = start_tile_function()
                       .add(builder::Input{src, "I"}.add_dims("D", 1, dim_limit + 1))
                       .add(builder::Output{"O"})
                       .add(builder::UnaryContraction{"="}
                                .set(builder::ContractionOutput{"O"}
                                         .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
                                             for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                             {
                                                 out = "d" + std::to_string(input_order[idx] + 1);
                                             }
                                         })
                                         .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                             for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                             {
                                                 out = "D" + std::to_string(input_order[idx] + 1);
                                             }
                                         }))
                                .set(builder::ContractionInput{"I"}.add_indices("d", 1, dim_limit + 1)))
                       .finalize();
             break;
         }
     }

     std::ostringstream reshape_expr;
     reshape_expr << "reshape(I";
     for (std::size_t dsize : op().get_output_shape())
     {
         reshape_expr << ", " << dsize;
     }
     reshape_expr << ")";

     set_output(start_tile_function()
                    .add(builder::Input{src, "I"})
                    .add(builder::Output{"O"})
                    .add(builder::Elementwise("O", reshape_expr.str()))
                    .finalize());
 }
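// Aside: what the input_order check above is for -- a Reshape whose
// input_order is not the identity is a transpose followed by a plain
// reshape.  A 2-D sketch (illustrative only; transpose_2d is a made-up
// helper):

#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> transpose_2d(const std::vector<float>& in, std::size_t rows, std::size_t cols)
{
    std::vector<float> out(in.size());
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
            out[c * rows + r] = in[r * cols + c]; // the rearranging contraction
    return out;
}

int main()
{
    // Reshape {2,3} -> {3,2} with input_order {1,0}: transpose first; the
    // subsequent row-major relayout to the output shape leaves the buffer as-is.
    for (auto v : transpose_2d({1, 2, 3, 4, 5, 6}, 2, 3))
        std::cout << v << " "; // 1 4 2 5 3 6
    std::cout << "\n";
}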
-// Select conditionally selects elements from input tensors.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Select>::operator()()
+            // Select conditionally selects elements from input tensors.
+            template <>
+            void Impl<op::Select>::operator()()
 {
     check_inputs(3);
     check_outputs(1);

     set_output(start_tile_function()
                    .add(builder::Input{op_input(0), "C"})
                    .add(builder::Input{op_input(1), "T"})
                    .add(builder::Input{op_input(2), "F"})
                    .add(builder::Output{"O"})
                    .add(builder::Elementwise{"O", "C ? T : F"})
                    .finalize());
 }
-// Used by nGraph for bprop graph generation, no-op as a kernel
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::StopGradient>::operator()()
+            // Used by nGraph for bprop graph generation, no-op as a kernel
+            template <>
+            void Impl<op::StopGradient>::operator()()
 {
     set_output(start_tile_function()
                    .add(builder::Output{"O"})
                    .add(builder::Elementwise{"O", "0"})
                    .finalize());
 }
-namespace
-{
-    ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::Registration register_broadcast;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::Registration register_constant;
-    ngraph::runtime::plaidml::Impl<ngraph::op::GetOutputElement>::Registration
-        register_get_output_element;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::Registration register_pad;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::Registration register_reshape;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Select>::Registration register_select;
-    ngraph::runtime::plaidml::Impl<ngraph::op::StopGradient>::Registration register_stop_gradient;
-}
+            namespace
+            {
+                Impl<op::Broadcast>::Registration register_broadcast;
+                Impl<op::Constant>::Registration register_constant;
+                Impl<op::GetOutputElement>::Registration register_get_output_element;
+                Impl<op::Pad>::Registration register_pad;
+                Impl<op::Reshape>::Registration register_reshape;
+                Impl<op::Select>::Registration register_select;
+                Impl<op::StopGradient>::Registration register_stop_gradient;
+            }
+        }
+    }
+}
src/ngraph/runtime/plaidml/plaidml_ops_index_reduction.cpp  View file @ 61df6725
@@ -36,111 +36,118 @@ namespace ngraph
                 void build_index_reduction(const char* agg_op);
             };
-        }
-    }
-}
-template <typename O>
-void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(const char* agg_op)
+            template <typename O>
+            void IndexReductionImpl<O>::build_index_reduction(const char* agg_op)
 {
     this->check_inputs(1);
     this->check_outputs(1);

     auto dim_limit = this->op().get_inputs()[0].get_shape().size();
     auto reduction_axis_str = std::to_string(this->op().get_reduction_axis());

     this->set_output(
         this->start_tile_function()
             .add(builder::Input{this->op_input(), "I"}.add_dims("D", 0, dim_limit))
             .add(builder::Output{"O"})
             .add(// Compute the maxes along the specified axis in the input
                  builder::UnaryContraction{agg_op}
                      .set(builder::ContractionOutput{"SelVal"}
                               .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
                                   for (auto idx = 0; idx < dim_limit; ++idx)
                                   {
                                       out = (idx == this->op().get_reduction_axis() ? "rd" : "d") +
                                             std::to_string(idx);
                                   }
                               })
                               .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                   for (auto idx = 0; idx < dim_limit; ++idx)
                                   {
                                       if (idx == this->op().get_reduction_axis())
                                       {
                                           out = "1";
                                       }
                                       else
                                       {
                                           out = "D" + std::to_string(idx);
                                       }
                                   }
                               }))
                      .set(builder::ContractionInput{"I"}.add_indices("d", 0, dim_limit)))
             .add(// Compare the input against the (broadcasted) max values, and select the indices
                  // where the max val occurs
                  builder::Elementwise{"SelValIdxs",
                                       "I == SelVal ? index(I, " + reduction_axis_str + ") : D" +
                                           reduction_axis_str})
             .add(// Select the maximum index
                  builder::UnaryContraction{"<"}
                      .set(builder::ContractionOutput{"SelIdx"}
                               .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
                                   for (auto idx = 0; idx < dim_limit; ++idx)
                                   {
                                       if (idx != this->op().get_reduction_axis())
                                       {
                                           out = "d" + std::to_string(idx);
                                       }
                                   }
                               })
                               .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                   for (auto idx = 0; idx < dim_limit; ++idx)
                                   {
                                       if (idx != this->op().get_reduction_axis())
                                       {
                                           out = "D" + std::to_string(idx);
                                       }
                                   }
                               }))
                      .set(builder::ContractionInput{"SelValIdxs"}.add_indices("d", 0, dim_limit)))
             .add(// Convert to the requested output element type (if any)
                  builder::Elementwise{"O",
                                       tile_converter("SelIdx",
                                                      this->op().get_index_element_type())})
             .finalize());
 }
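// Aside: a reference for the three-step Tile pipeline above, on a 1-D input --
// find the extreme value, mark each matching position with its index (all
// others get the axis length D, which acts as a "no match" sentinel), then
// min-reduce the marks, so ties resolve to the lowest index.  Illustrative
// only; argmax_1d is a made-up helper.

#include <cstddef>
#include <iostream>
#include <vector>

std::size_t argmax_1d(const std::vector<float>& in)
{
    float best = in[0];
    for (float v : in)
        best = v > best ? v : best; // step 1: aggregate with ">"
    std::size_t sel = in.size();    // the sentinel D<axis>
    for (std::size_t i = 0; i < in.size(); ++i)
        if (in[i] == best && i < sel) // steps 2-3: mark, then min-reduce
            sel = i;
    return sel;
}

int main()
{
    std::vector<float> v{3, 7, 7, 2};
    std::cout << argmax_1d(v) << "\n"; // 1: ties resolve to the lowest index
}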
-template <>
-struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::ArgMax>
-{
-    using Type = IndexReductionImpl<ngraph::op::ArgMax>;
-};
-
-template <>
-struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::ArgMin>
-{
-    using Type = IndexReductionImpl<ngraph::op::ArgMin>;
-};
+            template <>
+            struct ParentImpl<op::ArgMax>
+            {
+                using Type = IndexReductionImpl<op::ArgMax>;
+            };
+
+            template <>
+            struct ParentImpl<op::ArgMin>
+            {
+                using Type = IndexReductionImpl<op::ArgMin>;
+            };
-// ArgMax computes the maximum index along a tensor axis.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::ArgMax>::operator()()
-{
-    build_index_reduction(">");
-}
-
-// ArgMin computes the minimum index along a tensor axis.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::ArgMin>::operator()()
-{
-    build_index_reduction("<");
-}
+            // ArgMax computes the maximum index along a tensor axis.
+            template <>
+            void Impl<op::ArgMax>::operator()()
+            {
+                build_index_reduction(">");
+            }
+
+            // ArgMin computes the minimum index along a tensor axis.
+            template <>
+            void Impl<op::ArgMin>::operator()()
+            {
+                build_index_reduction("<");
+            }
-namespace
-{
-    ngraph::runtime::plaidml::Impl<ngraph::op::ArgMax>::Registration register_argmax;
-    ngraph::runtime::plaidml::Impl<ngraph::op::ArgMin>::Registration register_argmin;
-}
+            namespace
+            {
+                Impl<op::ArgMax>::Registration register_argmax;
+                Impl<op::ArgMin>::Registration register_argmin;
+            }
+        }
+    }
+}
src/ngraph/runtime/plaidml/plaidml_ops_io.cpp  View file @ 61df6725
@@ -20,35 +20,44 @@
 namespace vp = vertexai::plaidml;
-// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::operator()()
+namespace ngraph
+{
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
+            template <>
+            void Impl<op::Parameter>::operator()()
 {
     check_inputs(0);
     check_outputs(1);

     vp::placeholder ph{build()->io_dim_override ? build()->io_dim_override_count
                                                 : op().get_output_shape(0).size()};
     std::string name = std::string{"I"} + std::to_string(build()->input_names.size());
     descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
     build()->bindings.emplace(tv, TensorInfo{ph, TensorContents::DATA});
     build()->composer.input(name, ph);
     build()->input_names.emplace(tv, std::move(name));
 }
-// Result binds a PlaidML variable to a composed function output.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::Result>::operator()()
+            // Result binds a PlaidML variable to a composed function output.
+            template <>
+            void Impl<op::Result>::operator()()
 {
     check_inputs(1);
     check_outputs(1);

     std::string name = std::string{"O"} + std::to_string(build()->output_names.size());
     descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
     build()->composer.output(name, op_input());
     build()->output_names.emplace(tv, std::move(name));
 }
-namespace
-{
-    ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::Registration register_parameter;
-    ngraph::runtime::plaidml::Impl<ngraph::op::Result>::Registration register_result;
-}
+            namespace
+            {
+                Impl<op::Parameter>::Registration register_parameter;
+                Impl<op::Result>::Registration register_result;
+            }
+        }
+    }
+}
src/ngraph/runtime/plaidml/plaidml_ops_local_response_norm.cpp  View file @ 61df6725
@@ -17,40 +17,52 @@
 #include "ngraph/op/lrn.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
-// LRN implements Local Response Normalization
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::operator()()
+namespace ngraph
+{
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // LRN implements Local Response Normalization
+            template <>
+            void Impl<op::LRN>::operator()()
 {
     check_inputs(1);
     check_outputs(1);

     auto dim_limit = op().get_inputs()[0].get_shape().size();
     auto rank = dim_limit - 2;
     auto distance = op().get_nsize() / 2;
     std::ostringstream div_expr;
     div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha() << ".0 / "
              << op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)";

     set_output(
         start_tile_function()
             .add(builder::Input{op_input(), "I"}
                      .add_dims({"N", "C"})
                      .add_dims("D", 0, rank))
             .add(builder::Output{"O"})
             .add(builder::Elementwise{"ISQ", "I * I"})
             .add(builder::UnaryContraction{"+"}
                      .set(builder::ContractionOutput{"S"}
                               .add_indices({"n", "c"})
                               .add_indices("d", 0, rank)
                               .add_dims({"N", "C"})
                               .add_dims("D", 0, rank))
                      .set(builder::ContractionInput{"ISQ"}
                               .add_indices({"n", "c + z - " + std::to_string(distance)})
                               .add_indices("d", 0, rank))
                      .add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
                          out = "z < " + std::to_string(op().get_nsize());
                      }))
             .add(builder::Elementwise{"O", div_expr.str()})
             .finalize());
 }
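// Aside: the div_expr string encodes the usual LRN formula
// O[c] = I[c] / pow(bias + (alpha / nsize) * S[c], beta), where S[c] sums the
// squares of the nsize channels centred on c.  A scalar reference sketch for
// one pixel (illustrative only; out-of-range channels contribute nothing,
// matching the contraction's implicit bounds on the input):

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> lrn_channels(const std::vector<float>& in,
                                std::size_t nsize,
                                float alpha,
                                float beta,
                                float bias)
{
    std::size_t distance = nsize / 2;
    std::vector<float> out(in.size());
    for (std::size_t c = 0; c < in.size(); ++c)
    {
        float s = 0;
        for (std::size_t z = 0; z < nsize; ++z)
        {
            // Mirrors the contraction index "c + z - distance".
            std::ptrdiff_t idx =
                static_cast<std::ptrdiff_t>(c + z) - static_cast<std::ptrdiff_t>(distance);
            if (0 <= idx && idx < static_cast<std::ptrdiff_t>(in.size()))
                s += in[idx] * in[idx];
        }
        out[c] = in[c] / std::pow(bias + (alpha / nsize) * s, beta);
    }
    return out;
}

int main()
{
    for (auto v : lrn_channels({1, 2, 3, 4}, 3, 1e-4f, 0.75f, 2.0f))
        std::cout << v << " ";
    std::cout << "\n";
}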
-namespace
-{
-    ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::Registration register_local_response_norm;
-}
+            namespace
+            {
+                Impl<op::LRN>::Registration register_local_response_norm;
+            }
+        }
+    }
+}
src/ngraph/runtime/plaidml/plaidml_ops_logical.cpp  View file @ 61df6725
@@ -19,53 +19,62 @@
 #include "ngraph/op/or.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
-// And performs a simple elementwise logical and.
-template <>
-void ngraph::runtime::plaidml::Impl<ngraph::op::And>::operator()()
+namespace ngraph
+{
+    namespace runtime
+    {
+        namespace plaidml
+        {
+            // And performs a simple elementwise logical and.
+            template <>
+            void Impl<op::And>::operator()()
 {
     check_inputs(2);
     check_outputs(1);

     set_output(start_tile_function()
                    .add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"})
                    .add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"})
                    .add(builder::Output{"C"})
                    .add(builder::Elementwise{"C", "A ? B : A"})
                    .finalize(),
                TensorContents::LOGICAL);
 }
// Not performs a simple elementwise logical not.
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Not
>::
operator
()()
{
check_inputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(
0
,
TensorContents
::
LOGICAL
),
"I"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"cmp_eq(I, 0)"
})
.
finalize
(),
TensorContents
::
LOGICAL
);
}
// Not performs a simple elementwise logical not.
template
<>
void
Impl
<
op
::
Not
>::
operator
()()
{
check_inputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(
0
,
TensorContents
::
LOGICAL
),
"I"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"cmp_eq(I, 0)"
})
.
finalize
(),
TensorContents
::
LOGICAL
);
}
// Or performs a simple elementwise logical or.
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Or
>::
operator
()()
{
check_inputs
(
2
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(
0
,
TensorContents
::
LOGICAL
),
"A"
})
.
add
(
builder
::
Input
{
op_input
(
1
,
TensorContents
::
LOGICAL
),
"B"
})
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A ? A : B"
})
.
finalize
(),
TensorContents
::
LOGICAL
);
}
// Or performs a simple elementwise logical or.
template
<>
void
Impl
<
op
::
Or
>::
operator
()()
{
check_inputs
(
2
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(
0
,
TensorContents
::
LOGICAL
),
"A"
})
.
add
(
builder
::
Input
{
op_input
(
1
,
TensorContents
::
LOGICAL
),
"B"
})
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A ? A : B"
})
.
finalize
(),
TensorContents
::
LOGICAL
);
}
namespace
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
And
>::
Registration
register_and
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Not
>::
Registration
register_not
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Or
>::
Registration
register_or
;
namespace
{
Impl
<
op
::
And
>::
Registration
register_and
;
Impl
<
op
::
Not
>::
Registration
register_not
;
Impl
<
op
::
Or
>::
Registration
register_or
;
}
}
}
}
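Note how And and Or are written with Tile's ternary select rather than a dedicated boolean operator: for values already normalized to 0/1, "A ? B : A" yields B only when A is true (logical and), and "A ? A : B" yields A when it is true and B otherwise (logical or). A quick scalar check of the identities, purely illustrative:

#include <cassert>

int main()
{
    for (int a = 0; a <= 1; ++a)
    {
        for (int b = 0; b <= 1; ++b)
        {
            // "A ? B : A" from the And implementation above.
            assert((a ? b : a) == (a && b));
            // "A ? A : B" from the Or implementation above.
            assert((a ? a : b) == (a || b));
        }
    }
    // "cmp_eq(I, 0)" from Not: equal-to-zero is logical negation on 0/1 values.
    assert((0 == 0) == 1 && (1 == 0) == 0);
    return 0;
}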
src/ngraph/runtime/plaidml/plaidml_ops_one_hot.cpp
@@ -20,80 +20,93 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
(Pre-change, the same specialization appeared at global scope with fully-qualified names; the body is unchanged apart from the namespace nesting below.)

namespace ngraph
{
    namespace runtime
    {
        namespace plaidml
        {
            // OneHot performs one-hot encoding along the requested axis.
            template <>
            void Impl<op::OneHot>::operator()()
            {
                check_inputs(1);
                check_outputs(1);

                // Here's what's going on to implement OneHot:
                //
                // * We reshape the input tensor to add a size=1 dimension where we want
                //   the one-hot axis to be,
                //
                // * We create an index tensor that's size=1 on every dimension except the
                //   one-hot dimension,
                //
                // * We perform an elementwise conditional across them to assign the
                //   one-hot values.
                //
                // The broadcast rules will expand the index tensor on all non-one-hot
                // dimensions to match the input, and will expand the input tensor on the
                // one-hot dimension to match the index.
                //
                // In theory, it'd be pretty easy to implement all this with purely
                // elementwise operations.  The current definition of index() requires an
                // input tensor of the index() output shape, and it's a little tricky to
                // fix that, so we generate a zero tensor of the correct shape using a
                // contraction.  TODO: Optimize out the zero tensor contraction.

                const auto& in_shape = op().get_inputs()[0].get_shape();
                const auto& out_shape = op().get_shape();

                std::ostringstream in_reshape;
                for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
                {
                    if (idx)
                    {
                        in_reshape << ", ";
                    }
                    if (idx == op().get_one_hot_axis())
                    {
                        in_reshape << 1;
                    }
                    else
                    {
                        in_reshape << out_shape[idx];
                    }
                }

                set_output(
                    start_tile_function()
                        .add(builder::Input{op_input(), "I"}.add_dims("D", 0, in_shape.size()))
                        .add(builder::Input{static_cast<std::int64_t>(0), "Zero"})
                        .add(builder::Output{"O"})
                        .add(builder::UnaryContraction{"="}
                                 .set(builder::ContractionOutput{"ZS"}
                                          .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                              for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
                                              {
                                                  if (idx == op().get_one_hot_axis())
                                                  {
                                                      out = std::to_string(out_shape[idx]);
                                                  }
                                                  else
                                                  {
                                                      out = "1";
                                                  }
                                              }
                                          })
                                          .add_indices("d", 0, out_shape.size()))
                                 .set(builder::ContractionInput{"Zero"}))
                        .add(builder::Elementwise{
                            "Idx", "index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"})
                        .add(builder::Elementwise{"IS", "reshape(I, " + in_reshape.str() + ")"})
                        .add(builder::Elementwise{"OV", "IS == Idx ? 1 : 0"})
                        .add(builder::Elementwise{"O",
                                                  tile_converter("OV", op().get_element_type())})
                        .finalize());
            }

            namespace
            {
                Impl<op::OneHot>::Registration register_one_hot;
            }
        }
    }
}
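The comment block above describes a broadcast trick: reshape the input so the one-hot axis has extent 1, build an index ramp along that axis, and compare. Stripped of the Tile plumbing, the same idea for a 1-D input of class ids looks like the following sketch; the function name and shapes are illustrative only:

#include <cstddef>
#include <vector>

// One-hot a vector of class ids into an [N, depth] matrix by comparing each
// id against an index ramp along the one-hot axis -- the comparison that
// "OV = IS == Idx ? 1 : 0" performs above, with the broadcasting done by hand.
std::vector<std::vector<int>> one_hot(const std::vector<std::size_t>& ids, std::size_t depth)
{
    std::vector<std::vector<int>> out(ids.size(), std::vector<int>(depth, 0));
    for (std::size_t n = 0; n < ids.size(); ++n) // broadcast of IS over the one-hot axis
    {
        for (std::size_t d = 0; d < depth; ++d) // d plays the role of Idx = index(ZS, axis)
        {
            out[n][d] = (ids[n] == d) ? 1 : 0;
        }
    }
    return out;
}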
src/ngraph/runtime/plaidml/plaidml_ops_pool.cpp
@@ -20,293 +20,302 @@
#include "ngraph/runtime/plaidml/plaidml_convpool_formatter.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
(Pre-change, the same four specializations appeared at global scope with fully-qualified ngraph::runtime::plaidml::Impl<ngraph::op::...> names; the bodies are unchanged apart from the namespace nesting below.)

namespace ngraph
{
    namespace runtime
    {
        namespace plaidml
        {
            // AvgPool implements a batch average pooling operation.
            template <>
            void Impl<op::AvgPool>::operator()()
            {
                check_inputs(1);
                check_outputs(1);

                auto src_dims = op().get_inputs()[0].get_shape().size() - 2;
                const auto& padding_above = op().get_padding_above();
                const auto& padding_below = op().get_padding_below();
                const auto& window_shape = op().get_window_shape();
                const auto& strides = op().get_window_movement_strides();
                const auto& include_padding = op().get_include_padding_in_avg_computation();

                ngraph::CoordinateDiff pad_above;
                ngraph::CoordinateDiff pad_below;
                for (const auto& pad : padding_above)
                {
                    pad_above.push_back(pad);
                }
                for (const auto& pad : padding_below)
                {
                    pad_below.push_back(pad);
                }

                // Overpadding occurs iff any padding value is >= its corresponding window
                // shape.  If this happens, we need to conditionally set the padded values
                // to the operation default.
                bool overpad = false;
                for (std::size_t idx = 0; idx < src_dims; ++idx)
                {
                    auto shape = window_shape[idx];
                    if (shape <= padding_below[idx] || shape <= padding_above[idx])
                    {
                        overpad = true;
                        break;
                    }
                }
                if (overpad)
                {
                    throw std::runtime_error{
                        "The PlaidML nGraph backend does not support over-padded AvgPool "
                        "operations"};
                }

                ConvPoolFormatter cpf(src_dims,
                                      pad_below,
                                      pad_above,
                                      strides,
                                      window_shape,
                                      ConvPoolFormatter::OpType::AvgPool,
                                      ConvPoolFormatter::DerivType::None);

                vertexai::plaidml::variable one{static_cast<std::int64_t>(1)};
                auto f = start_tile_function();
                f.add(cpf.I_in_header(op_input()))
                    .add(builder::Input{one, "One"})
                    .add(cpf.O_out_header())
                    .add(cpf.Broadcast_Ones());
                if (include_padding)
                {
                    f.add(builder::Elementwise{"Count", std::to_string(shape_size(window_shape))});
                }
                else
                {
                    f.add(cpf.Count());
                }
                f.add(cpf.PoolContraction()).add(builder::Elementwise{"O", "S / Count"});
                set_output(f.finalize());
            }

            // MaxPool implements a batch max pooling operation.
            template <>
            void Impl<op::MaxPool>::operator()()
            {
                check_inputs(1);
                check_outputs(1);

                auto src_dims = op().get_inputs()[0].get_shape().size() - 2;
                const auto& padding_above = op().get_padding_above();
                const auto& padding_below = op().get_padding_below();
                const auto& window_shape = op().get_window_shape();
                const auto& strides = op().get_window_movement_strides();

                ngraph::CoordinateDiff pad_above;
                ngraph::CoordinateDiff pad_below;
                for (const auto& pad : padding_above)
                {
                    pad_above.push_back(pad);
                }
                for (const auto& pad : padding_below)
                {
                    pad_below.push_back(pad);
                }

                NGRAPH_DEBUG << "MaxPool padding_below: " << padding_below;
                NGRAPH_DEBUG << "MaxPool padding_above: " << padding_above;
                NGRAPH_DEBUG << "MaxPool window_shape: " << window_shape;
                NGRAPH_DEBUG << "MaxPool window_movement_strides: " << strides;

                // Overpadding occurs iff any padding value is >= its corresponding window
                // shape.  If this happens, we need to conditionally set the padded values
                // to the operation default.
                bool overpad = false;
                for (std::size_t idx = 0; idx < src_dims; ++idx)
                {
                    auto shape = window_shape[idx];
                    if (shape <= padding_below[idx] || shape <= padding_above[idx])
                    {
                        overpad = true;
                        break;
                    }
                }
                if (overpad)
                {
                    throw std::runtime_error{
                        "The PlaidML nGraph backend does not support over-padded MaxPool "
                        "operations"};
                }

                ConvPoolFormatter cpf(src_dims,
                                      pad_below,
                                      pad_above,
                                      strides,
                                      window_shape,
                                      ConvPoolFormatter::OpType::MaxPool,
                                      ConvPoolFormatter::DerivType::None);

                set_output(start_tile_function()
                               .add(cpf.I_in_header(op_input()))
                               .add(cpf.O_out_header())
                               .add(cpf.PoolContraction())
                               .finalize());
            }

            template <>
            void Impl<op::AvgPoolBackprop>::operator()()
            {
                check_inputs(1);
                check_outputs(1);

                auto src_dims = op().get_inputs()[0].get_shape().size() - 2;
                const auto& forward_arg_shape = op().get_forward_arg_shape();
                const auto& padding_above = op().get_padding_above();
                const auto& padding_below = op().get_padding_below();
                const auto& window_shape = op().get_window_shape();
                const auto& strides = op().get_window_movement_strides();
                const auto& include_padding = op().get_include_padding_in_avg_computation();

                if (include_padding)
                {
                    throw std::runtime_error(
                        "Include padding in average not yet implemented in PlaidML");
                }

                ngraph::CoordinateDiff pad_above;
                ngraph::CoordinateDiff pad_below;
                for (const auto& pad : padding_above)
                {
                    pad_above.push_back(pad);
                }
                for (const auto& pad : padding_below)
                {
                    pad_below.push_back(pad);
                }

                // Overpadding occurs iff any padding value is >= its corresponding window
                // shape.  If this happens, we need to conditionally set the padded values
                // to the operation default.
                bool overpad = false;
                for (std::size_t idx = 0; idx < src_dims; ++idx)
                {
                    auto shape = window_shape[idx];
                    if (shape <= padding_below[idx] || shape <= padding_above[idx])
                    {
                        overpad = true;
                        break;
                    }
                }
                if (overpad)
                {
                    throw std::runtime_error{
                        "The PlaidML nGraph backend does not support over-padded AvgPool "
                        "operations"};
                }

                ConvPoolFormatter cpf(src_dims,
                                      pad_below,
                                      pad_above,
                                      strides,
                                      window_shape,
                                      ConvPoolFormatter::OpType::AvgPool,
                                      ConvPoolFormatter::DerivType::Data);

                const auto& incoming_deriv = op_input();
                vertexai::plaidml::variable one{static_cast<std::int64_t>(1)};
                auto ret = start_tile_function();
                ret.add(cpf.O_in_header(incoming_deriv))
                    .add(builder::Input{one, "One"})
                    .add(builder::Output{"DI"});
                for (int i = 2; i < forward_arg_shape.size(); ++i)
                {
                    std::ostringstream s;
                    s << "XI" << i - 2;
                    ret.add(builder::Input{static_cast<std::int64_t>(forward_arg_shape[i]), s.str()});
                }
                set_output(ret.add(cpf.Broadcast_Ones())
                               .add(cpf.Count())
                               .add(builder::Elementwise{"S", "DO / Count"})
                               .add(cpf.PoolContraction())
                               .finalize());
            }

            template <>
            void Impl<op::MaxPoolBackprop>::operator()()
            {
                check_inputs(2);
                check_outputs(1);

                auto src_dims = op().get_inputs()[0].get_shape().size() - 2;
                const auto& padding_above = op().get_padding_above();
                const auto& padding_below = op().get_padding_below();
                const auto& window_shape = op().get_window_shape();
                const auto& strides = op().get_window_movement_strides();

                ngraph::CoordinateDiff pad_above;
                ngraph::CoordinateDiff pad_below;
                for (const auto& pad : padding_above)
                {
                    pad_above.push_back(pad);
                }
                for (const auto& pad : padding_below)
                {
                    pad_below.push_back(pad);
                }

                // Overpadding occurs iff any padding value is >= its corresponding window
                // shape.  If this happens, we need to conditionally set the padded values
                // to the operation default.
                bool overpad = false;
                for (std::size_t idx = 0; idx < src_dims; ++idx)
                {
                    auto shape = window_shape[idx];
                    if (shape <= padding_below[idx] || shape <= padding_above[idx])
                    {
                        overpad = true;
                        break;
                    }
                }
                if (overpad)
                {
                    throw std::runtime_error{
                        "The PlaidML nGraph backend does not support over-padded MaxPool "
                        "operations"};
                }

                ConvPoolFormatter cpf(src_dims,
                                      pad_below,
                                      pad_above,
                                      strides,
                                      window_shape,
                                      ConvPoolFormatter::OpType::MaxPool,
                                      ConvPoolFormatter::DerivType::Data);

                const auto& input = op_input(0);
                const auto& incoming_deriv = op_input(1);
                set_output(start_tile_function()
                               .add(cpf.I_in_header(input))
                               .add(cpf.O_in_header(incoming_deriv))
                               .add(builder::Output{"DI"})
                               .add(cpf.PoolContraction())
                               .add(cpf.PoolDerivContraction())
                               .finalize());
            }

            namespace
            {
                Impl<op::AvgPool>::Registration register_avg_pool;
                Impl<op::MaxPool>::Registration register_max_pool;
                Impl<op::AvgPoolBackprop>::Registration register_avg_pool_backprop;
                Impl<op::MaxPoolBackprop>::Registration register_max_pool_backprop;
            }
        }
    }
}
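All four pooling implementations repeat the same guard: if any pad extent reaches the window size, a window position can fall entirely inside the padding, which would force the backend to materialize the operation's identity value there. A compact sketch of just that predicate, with assumed equal-length shape vectors for illustration:

#include <cstddef>
#include <vector>

// True iff some spatial dimension is padded by at least a full window -- the
// "overpad" condition each pooling implementation above rejects.
bool is_overpadded(const std::vector<std::size_t>& window_shape,
                   const std::vector<std::size_t>& padding_below,
                   const std::vector<std::size_t>& padding_above)
{
    for (std::size_t idx = 0; idx < window_shape.size(); ++idx)
    {
        if (window_shape[idx] <= padding_below[idx] || window_shape[idx] <= padding_above[idx])
        {
            return true;
        }
    }
    return false;
}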
src/ngraph/runtime/plaidml/plaidml_ops_reduce.cpp
@@ -42,222 +42,236 @@ namespace ngraph
void build_reduction(const char* agg_op);
};

(The definitions below, pre-change written at global scope with fully-qualified names, now live inside the namespaces opened in the hunk context above.)

template <typename O>
void ReductionImpl<O>::build_reduction(const char* agg_op)
{
    this->check_inputs(1);
    this->check_outputs(1);

    auto in_shape = this->op().get_input_shape(0);
    auto in_dim_limit = in_shape.size();

    std::vector<std::size_t> out_idxs;
    for (std::size_t in_idx = 0; in_idx < in_dim_limit; ++in_idx)
    {
        if (!this->op().get_reduction_axes().count(in_idx))
        {
            out_idxs.push_back(in_idx);
        }
    }

    this->set_output(
        this->start_tile_function()
            .add(builder::Output{"O"})
            .add(builder::Input{this->op_input(0), "I"}.add_dims("D", 1, in_dim_limit + 1))
            .add(builder::UnaryContraction{agg_op}
                     .set(builder::ContractionOutput{"O"}
                              .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
                                  for (std::size_t idx = 0; idx < out_idxs.size(); ++idx)
                                  {
                                      out = "d" + std::to_string(out_idxs[idx] + 1);
                                  }
                              })
                              .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                  for (std::size_t idx = 0; idx < out_idxs.size(); ++idx)
                                  {
                                      out = "D" + std::to_string(out_idxs[idx] + 1);
                                  }
                              }))
                     .set(builder::ContractionInput{"I"}.add_indices("d", 1, in_dim_limit + 1)))
            .finalize());
}

template <>
struct ParentImpl<op::Max>
{
    using Type = ReductionImpl<op::Max>;
};

template <>
struct ParentImpl<op::Min>
{
    using Type = ReductionImpl<op::Min>;
};

template <>
struct ParentImpl<op::Product>
{
    using Type = ReductionImpl<op::Product>;
};

template <>
struct ParentImpl<op::Reduce>
{
    using Type = ReductionImpl<op::Reduce>;
};

template <>
struct ParentImpl<op::Sum>
{
    using Type = ReductionImpl<op::Sum>;
};

// Max reduces a tensor, taking the maximum along the specified axes.
template <>
void Impl<op::Max>::operator()()
{
    build_reduction(">");
}

// Min reduces a tensor, taking the minimum along the specified axes.
template <>
void Impl<op::Min>::operator()()
{
    build_reduction("<");
}

// Product reduces a tensor, taking the product along the specified axes.
template <>
void Impl<op::Product>::operator()()
{
    build_reduction("*");
}

// Reduce reduces a tensor with an arbitrary user-supplied reduction operation.
template <>
void Impl<op::Reduce>::operator()()
{
    check_inputs(2);
    check_outputs(1);

    // TODO: Special case known-easy reductions.

    // To support arbitrary reduction operations, we take advantage of the fact
    // that in nGraph, we have concrete dimension sizes.  We start with the
    // initial tensor (argument 1), construct N slices of tensor 0 (where N ==
    // the product of the sizes of the axes to reduce), and repeatedly apply
    // the supplied aggregation function to them.
    //
    // This is somewhat inefficient, but works.

    const Shape& input_shape = op().get_input_shape(0);
    auto dim_limit = input_shape.size();
    Shape reduction_shape;
    for (std::size_t axis_idx = 0; axis_idx < input_shape.size(); ++axis_idx)
    {
        if (op().get_reduction_axes().count(axis_idx))
        {
            reduction_shape.emplace_back(input_shape[axis_idx]);
        }
    }
    std::size_t agg_dim_limit = dim_limit - reduction_shape.size();

    vp::function agg_fn;
    {
        Build b;
        b.io_dim_override = true;
        b.io_dim_override_count = agg_dim_limit;
        build()->compiler->build(op().get_functions()[0], &b);
        agg_fn = b.composer;
    }

    vp::variable input = op_input(0);

    // Note that we need to explicitly broadcast the 0-dimensional base result
    // to match the aggregation dimension count.
    vp::variable result =
        start_tile_function()
            .add(builder::Input{op_input(1), "I"})
            .add(builder::Output{"O"})
            .add(builder::UnaryContraction{"="}
                     .set(builder::ContractionOutput{"O"}
                              .add_indices("d", 0, agg_dim_limit)
                              .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                  for (auto idx = 0; idx < agg_dim_limit; ++idx)
                                  {
                                      out = "1";
                                  }
                              }))
                     .set(builder::ContractionInput{"I"}))
            .finalize();

    CoordinateTransform reduction_coords{reduction_shape};
    for (const Coordinate& coordinate : reduction_coords)
    {
        result = agg_fn(
            result,
            start_tile_function()
                .add(builder::Input{input, "I"}.add_dims("D", 0, dim_limit))
                .add(builder::Output{"O"})
                .add(builder::UnaryContraction{"="}
                         .set(builder::ContractionOutput{"O"}
                                  .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
                                      for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
                                      {
                                          if (!op().get_reduction_axes().count(idx))
                                          {
                                              out = "d" + std::to_string(idx);
                                          }
                                      }
                                  })
                                  .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                      for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
                                      {
                                          if (!op().get_reduction_axes().count(idx))
                                          {
                                              out = "D" + std::to_string(idx);
                                          }
                                      }
                                  }))
                         .set(builder::ContractionInput{"I"}.add_indices(
                             [&](std::back_insert_iterator<std::list<std::string>> out) {
                                 // cidx must persist across the loop so that each reduced
                                 // axis picks up its own coordinate component.
                                 std::size_t cidx = 0;
                                 for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
                                 {
                                     if (!op().get_reduction_axes().count(idx))
                                     {
                                         out = "d" + std::to_string(idx);
                                     }
                                     else
                                     {
                                         out = std::to_string(coordinate[cidx++]);
                                     }
                                 }
                             })))
                .finalize());
    }

    set_output(result);
}

// Sum reduces a tensor, summing the specified axes.
template <>
void Impl<op::Sum>::operator()()
{
    build_reduction("+");
}

namespace
{
    Impl<op::Max>::Registration register_max;
    Impl<op::Min>::Registration register_min;
    Impl<op::Product>::Registration register_product;
    Impl<op::Reduce>::Registration register_reduce;
    Impl<op::Sum>::Registration register_sum;
}
}
}
}
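The Reduce path above is a left fold: it seeds the result with argument 1 and applies the user-supplied aggregation function once per coordinate of the reduced axes, each time against the matching slice of argument 0. Stripped of the Tile plumbing, the control flow is just this; a scalar sketch with a stand-in aggregation function, illustrative only:

#include <cstddef>
#include <functional>
#include <vector>

// Fold agg_fn over the reduced (column) axis of an [N, R] matrix, seeding
// every row with `init` -- one fold step per reduced coordinate, matching
// the per-slice agg_fn calls the Reduce implementation above emits.
std::vector<float> reduce_rows(const std::vector<std::vector<float>>& input,
                               float init,
                               const std::function<float(float, float)>& agg_fn)
{
    std::vector<float> result(input.size(), init);
    if (input.empty())
    {
        return result;
    }
    for (std::size_t coordinate = 0; coordinate < input[0].size(); ++coordinate)
    {
        for (std::size_t n = 0; n < input.size(); ++n)
        {
            result[n] = agg_fn(result[n], input[n][coordinate]); // result = agg_fn(result, slice)
        }
    }
    return result;
}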
src/ngraph/runtime/plaidml/plaidml_ops_replace_slice.cpp
@@ -19,74 +19,87 @@
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
(Pre-change, the same specialization appeared at global scope with fully-qualified names; the body is unchanged apart from the namespace nesting below.)

namespace ngraph
{
    namespace runtime
    {
        namespace plaidml
        {
            // ReplaceSlice replaces part of a tensor with another tensor.
            template <>
            void Impl<op::ReplaceSlice>::operator()()
            {
                check_inputs(2);
                check_outputs(1);

                // For ReplaceSlice:
                //
                // * Pad the second tensor to match the first (same-size dimensions and
                //   offset according to the lower bounds of the replacement, with the
                //   desired stridings)
                //
                // * Generate a boolean tensor of the same shape as the first, where
                //   true == "Do the replacement".
                //
                // * Use a trinary to do the replacement.

                const auto& shape = op().get_shape();

                set_output(
                    start_tile_function()
                        .add(builder::Input{op_input(0), "L"}.add_dims("D", 0, shape.size()))
                        .add(builder::Input{op_input(1), "S"}.add_dims("SD", 0, shape.size()))
                        .add(builder::Output{"O"})
                        .add(builder::UnaryContraction{"="}
                                 .set(builder::ContractionOutput{"O"}
                                          .add_dims("D", 0, shape.size())
                                          .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
                                              for (std::size_t idx = 0; idx < shape.size(); ++idx)
                                              {
                                                  auto stride = op().get_strides()[idx];
                                                  auto lower_bound = op().get_lower_bounds()[idx];
                                                  std::ostringstream didx;
                                                  if ((stride != 1) && lower_bound)
                                                  {
                                                      didx << "(";
                                                  }
                                                  didx << "d" << idx;
                                                  if (stride != 1)
                                                  {
                                                      didx << "*" << stride;
                                                  }
                                                  if ((stride != 1) && lower_bound)
                                                  {
                                                      didx << ")";
                                                  }
                                                  if (lower_bound)
                                                  {
                                                      didx << "+" << lower_bound;
                                                  }
                                                  out = didx.str();
                                              }
                                          }))
                                 .set(builder::ContractionInput{"S"}.add_indices("d", 0, shape.size()))
                                 .add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
                                     for (std::size_t idx = 0; idx < shape.size(); ++idx)
                                     {
                                         out = "d" + std::to_string(idx) + " < " +
                                               std::to_string(op().get_upper_bounds()[idx] -
                                                              op().get_lower_bounds()[idx]);
                                     }
                                 })
                                 .set_default("L"))
                        .finalize());
            }

            namespace
            {
                Impl<op::ReplaceSlice>::Registration register_replace_slice;
            }
        }
    }
}
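The contraction above writes S into O at output index d*stride + lower_bound per dimension, constrains d to the replaced extent, and falls back to L everywhere else via set_default. In scalar 1-D form the index and bound arithmetic reads roughly as follows; the function name and the explicit bound checks are illustrative, not backend code:

#include <cstddef>
#include <vector>

// 1-D ReplaceSlice sketch: write S over L at lower_bound, lower_bound+stride, ...
// "d < upper_bound - lower_bound" mirrors the constraint emitted above, and
// S's own extent bounds d just as the SD dimensions do in the Tile code.
std::vector<float> replace_slice_1d(std::vector<float> L, const std::vector<float>& S,
                                    std::size_t lower_bound, std::size_t upper_bound,
                                    std::size_t stride)
{
    for (std::size_t d = 0; d < S.size() && d < upper_bound - lower_bound; ++d)
    {
        std::size_t target = d * stride + lower_bound; // the "(d*stride)+lower_bound" index
        if (target < L.size())
        {
            L[target] = S[d];
        }
    }
    return L;
}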
src/ngraph/runtime/plaidml/plaidml_ops_reverse.cpp
@@ -19,41 +19,50 @@
#include "ngraph/op/reverse.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
(Pre-change, the same specialization appeared at global scope with fully-qualified names; the body is unchanged apart from the namespace nesting below.)

namespace ngraph
{
    namespace runtime
    {
        namespace plaidml
        {
            // Reverse reverses the selected axes within a tensor.
            template <>
            void Impl<op::Reverse>::operator()()
            {
                check_inputs(1);
                check_outputs(1);

                const auto& shape = op().get_shape();

                set_output(
                    start_tile_function()
                        .add(builder::Input{op_input(), "I"}.add_dims("D", 0, shape.size()))
                        .add(builder::Output{"O"})
                        .add(builder::UnaryContraction{"="}
                                 .set(builder::ContractionOutput{"O"}
                                          .add_indices("d", 0, shape.size())
                                          .add_dims("D", 0, shape.size()))
                                 .set(builder::ContractionInput{"I"}.add_indices(
                                     [&](std::back_insert_iterator<std::list<std::string>> out) {
                                         for (std::size_t idx = 0; idx < shape.size(); ++idx)
                                         {
                                             auto sidx = std::to_string(idx);
                                             if (op().get_reversed_axes().count(idx))
                                             {
                                                 out = "D" + sidx + "-d" + sidx + "-1";
                                             }
                                             else
                                             {
                                                 out = "d" + sidx;
                                             }
                                         }
                                     })))
                        .finalize());
            }

            namespace
            {
                Impl<op::Reverse>::Registration register_reverse;
            }
        }
    }
}
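Reversal is pure index remapping: along a reversed axis of extent D, output element d reads input element D - d - 1, which is exactly the "D<idx>-d<idx>-1" index string built above. A 1-D scalar sketch, illustrative only:

#include <cstddef>
#include <vector>

// Reverse a vector by the same index formula the contraction above emits:
// O[d] = I[D - d - 1].
std::vector<float> reverse_axis(const std::vector<float>& I)
{
    std::size_t D = I.size();
    std::vector<float> O(D);
    for (std::size_t d = 0; d < D; ++d)
    {
        O[d] = I[D - d - 1];
    }
    return O;
}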
src/ngraph/runtime/plaidml/plaidml_ops_slice.cpp
@@ -18,87 +18,100 @@
#include "ngraph/op/slice.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
(Pre-change, the same specialization appeared at global scope with fully-qualified names; the body is unchanged apart from the namespace nesting below.)

namespace ngraph
{
    namespace runtime
    {
        namespace plaidml
        {
            // Slice takes a sub-slice of a tensor.
            template <>
            void Impl<op::Slice>::operator()()
            {
                check_inputs(1);
                check_outputs(1);

                NGRAPH_DEBUG << "Slice: low: " << op().get_lower_bounds();
                NGRAPH_DEBUG << "Slice high: " << op().get_upper_bounds();
                NGRAPH_DEBUG << "Slice stride: " << op().get_strides();

                const auto& shape = op().get_inputs()[0].get_shape();
                auto dim_limit = shape.size();

                set_output(
                    start_tile_function()
                        .add(builder::Input{op_input(), "I"}.add_dims("ID", 0, dim_limit))
                        .add(builder::Output{"O"})
                        .add(builder::UnaryContraction{"="}
                                 .set(builder::ContractionOutput{"O"}
                                          .add_indices("od", 0, dim_limit)
                                          .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                              for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                              {
                                                  std::ostringstream s;
                                                  std::size_t stride = op().get_strides()[idx];
                                                  std::ptrdiff_t trim_count =
                                                      op().get_lower_bounds()[idx] +
                                                      (shape[idx] - op().get_upper_bounds()[idx]) +
                                                      1 - stride;
                                                  if ((stride != 1) && trim_count)
                                                  {
                                                      s << "(";
                                                  }
                                                  s << "ID" << idx;
                                                  if (0 < trim_count)
                                                  {
                                                      s << " - " << trim_count;
                                                  }
                                                  if (trim_count < 0)
                                                  {
                                                      s << " + " << -trim_count;
                                                  }
                                                  if ((stride != 1) && trim_count)
                                                  {
                                                      s << ")";
                                                  }
                                                  if (stride != 1)
                                                  {
                                                      s << " / " << stride;
                                                  }
                                                  out = s.str();
                                              }
                                          }))
                                 .set(builder::ContractionInput{"I"}.add_indices(
                                     [&](std::back_insert_iterator<std::list<std::string>> out) {
                                         for (std::size_t idx = 0; idx < dim_limit; ++idx)
                                         {
                                             std::ostringstream s;
                                             std::size_t stride = op().get_strides()[idx];
                                             std::size_t offset = op().get_lower_bounds()[idx];
                                             if ((stride != 1) && offset)
                                             {
                                                 s << "(";
                                             }
                                             s << "od" << idx;
                                             if (stride != 1)
                                             {
                                                 s << " * " << stride;
                                             }
                                             if ((stride != 1) && offset)
                                             {
                                                 s << ")";
                                             }
                                             if (offset)
                                             {
                                                 s << " + " << offset;
                                             }
                                             out = s.str();
                                         }
                                     })))
                        .finalize());
            }

            namespace
            {
                Impl<op::Slice>::Registration register_slice;
            }
        }
    }
}
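The output-dimension string above computes (ID - trim_count) / stride with trim_count = lower + (ID - upper) + 1 - stride; algebraically ID - trim_count = (upper - lower) + stride - 1, so the integer division is a closed form for ceil((upper - lower) / stride), while the input index is od*stride + lower. A quick arithmetic self-check of that identity (sample bounds chosen so trim_count stays non-negative):

#include <cassert>
#include <cstddef>

int main()
{
    // {ID, lower, upper, stride} sample slices.
    std::size_t cases[][4] = {{10, 2, 9, 3}, {10, 0, 10, 1}, {7, 1, 7, 2}};
    for (auto& c : cases)
    {
        std::size_t ID = c[0], lower = c[1], upper = c[2], stride = c[3];
        // trim_count = lower + (ID - upper) + 1 - stride, as built above.
        std::ptrdiff_t trim_count = static_cast<std::ptrdiff_t>(lower + (ID - upper) + 1) -
                                    static_cast<std::ptrdiff_t>(stride);
        std::size_t via_trim = (ID - static_cast<std::size_t>(trim_count)) / stride;
        std::size_t via_ceil = (upper - lower + stride - 1) / stride; // ceil((upper-lower)/stride)
        assert(via_trim == via_ceil);
    }
    return 0;
}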
src/ngraph/runtime/plaidml/plaidml_ops_softmax.cpp
@@ -19,149 +19,162 @@
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
(Pre-change, the same specialization appeared at global scope with fully-qualified names; the body is unchanged apart from the namespace nesting below.)

namespace ngraph
{
    namespace runtime
    {
        namespace plaidml
        {
            // Softmax implements a standard ML softmax operation.
            template <>
            void Impl<op::Softmax>::operator()()
            {
                check_inputs(1);
                check_outputs(1);

                const auto& shape = op().get_inputs()[0].get_shape();
                auto dim_limit = shape.size();

                auto f = start_tile_function();
                f.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, dim_limit))
                    .add(builder::Output{"O"});

                bool reorder_needed = false;
                bool saw_element = false;
                auto groups = 1;
                auto elements = 1;
                std::vector<std::size_t> group_idxs;
                std::vector<std::size_t> element_idxs;

                for (auto didx = 0; didx < shape.size(); ++didx)
                {
                    if (op().get_axes().count(didx))
                    {
                        elements *= shape[didx];
                        element_idxs.push_back(didx);
                        saw_element = true;
                    }
                    else
                    {
                        groups *= shape[didx];
                        group_idxs.push_back(didx);
                        if (saw_element)
                        {
                            reorder_needed = true;
                        }
                    }
                }

                const char* input = "I";
                const char* output = "O";
                const char* reshape_output = output;
                bool reshape_needed = dim_limit != 2;

                if (!reorder_needed)
                {
                    reshape_needed |= shape[0] != groups;
                }
                else
                {
                    f.add(builder::UnaryContraction{"="}
                              .set(builder::ContractionOutput{"RI"}
                                       .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
                                           for (auto idx : group_idxs)
                                           {
                                               out = "D" + std::to_string(idx);
                                           }
                                           for (auto idx : element_idxs)
                                           {
                                               out = "D" + std::to_string(idx);
                                           }
                                       })
                                       .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
                                           for (auto idx : group_idxs)
                                           {
                                               out = "d" + std::to_string(idx);
                                           }
                                           for (auto idx : element_idxs)
                                           {
                                               out = "d" + std::to_string(idx);
                                           }
                                       }))
                              .set(builder::ContractionInput{"I"}.add_indices("d", 0, dim_limit)));
                    input = "RI";
                    output = "RO";
                    if (group_idxs.size())
                    {
                        reshape_needed |= shape[group_idxs[0]] != groups;
                    }
                    else
                    {
                        reshape_needed |= shape[element_idxs[0]] != groups;
                    }
                }

                if (reshape_needed)
                {
                    std::ostringstream reshape;
                    reshape << "reshape(" << input << ", " << groups << ", " << elements << ")";
                    f.add(builder::Elementwise{"GI", reshape.str()});
                    input = "GI";
                    reshape_output = output;
                    output = "GO";
                }

                {
                    // Take the softmax.
                    std::ostringstream softmax;
                    softmax << "builtin_softmax(" << input << ", " << groups << ", " << elements
                            << ")";
                    f.add(builder::Elementwise{output, softmax.str()});
                }

                if (reshape_needed)
                {
                    // Unbundle the axes.
                    std::ostringstream reshape;
                    reshape << "reshape(GO";
                    for (auto didx : group_idxs)
                    {
                        reshape << ", " << shape[didx];
                    }
                    for (auto didx : element_idxs)
                    {
                        reshape << ", " << shape[didx];
                    }
                    reshape << ")";
                    f.add(builder::Elementwise{reshape_output, reshape.str()});
                    output = reshape_output;
                }

                if (reorder_needed)
                {
                    f.add(builder::UnaryContraction{"="}
                              .set(builder::ContractionOutput{"O"}
                                       .add_dims("D", 0, dim_limit)
                                       .add_indices("d", 0, dim_limit))
                              .set(builder::ContractionInput{output}.add_indices(
                                  [&](std::back_insert_iterator<std::list<std::string>> out) {
                                      for (auto idx : group_idxs)
                                      {
                                          out = "d" + std::to_string(idx);
                                      }
                                      for (auto idx : element_idxs)
                                      {
                                          out = "d" + std::to_string(idx);
                                      }
                                  })));
                }

                set_output(f.finalize());
            }

            namespace
            {
                Impl<op::Softmax>::Registration register_softmax;
            }
        }
    }
}
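All of the bookkeeping above exists because builtin_softmax consumes a 2-D [groups, elements] view: non-softmax axes are multiplied into groups, softmax axes into elements, with an extra contraction when a group axis follows an element axis. The flattened computation itself is ordinary row-wise softmax; a scalar sketch over such a view follows. It is illustrative only, and it applies the usual max-shift for numeric stability, which may or may not match what builtin_softmax does internally:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax over a flattened [groups, elements] view -- the 2-D shape the code
// above reshapes to before calling builtin_softmax.
std::vector<float> softmax_2d(const std::vector<float>& in,
                              std::size_t groups, std::size_t elements)
{
    std::vector<float> out(in.size());
    for (std::size_t g = 0; g < groups; ++g)
    {
        float max_v = in[g * elements];
        for (std::size_t e = 1; e < elements; ++e)
        {
            max_v = std::max(max_v, in[g * elements + e]);
        }
        float denom = 0;
        for (std::size_t e = 0; e < elements; ++e)
        {
            denom += std::exp(in[g * elements + e] - max_v);
        }
        for (std::size_t e = 0; e < elements; ++e)
        {
            out[g * elements + e] = std::exp(in[g * elements + e] - max_v) / denom;
        }
    }
    return out;
}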
src/ngraph/runtime/plaidml/plaidml_ops_transcendental.cpp
@@ -29,189 +29,198 @@
#include "ngraph/op/tanh.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// acos performs a simple elementwise arccos function.
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Acos
>::
operator
()()
namespace
ngraph
{
check_inputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(
0
),
"I"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"acos(I)"
})
.
finalize
());
}
namespace
runtime
{
namespace
plaidml
{
// acos performs a simple elementwise arccos function.
template
<>
void
Impl
<
op
::
Acos
>::
operator
()()
{
check_inputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(
0
),
"I"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"acos(I)"
})
.
finalize
());
}
            // asin performs a simple elementwise arcsin function.
            template <>
            void Impl<op::Asin>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "asin(I)"})
                               .finalize());
            }

            // atan performs a simple elementwise arctan function.
            template <>
            void Impl<op::Atan>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "atan(I)"})
                               .finalize());
            }

            // cos performs a simple elementwise cos function.
            template <>
            void Impl<op::Cos>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "cos(I)"})
                               .finalize());
            }

            // cosh performs a simple elementwise hyperbolic cos function.
            template <>
            void Impl<op::Cosh>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "cosh(I)"})
                               .finalize());
            }

            // exp performs a simple elementwise natural exponential function.
            template <>
            void Impl<op::Exp>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "exp(I)"})
                               .finalize());
            }

            // log performs a simple elementwise natural logarithm function.
            template <>
            void Impl<op::Log>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "log(I)"})
                               .finalize());
            }

            // power performs a simple elementwise power function.
            template <>
            void Impl<op::Power>::operator()()
            {
                check_inputs(2);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Input{op_input(1), "E"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "pow(I, E)"})
                               .finalize());
            }

            // sin performs a simple elementwise sin function.
            template <>
            void Impl<op::Sin>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "sin(I)"})
                               .finalize());
            }

            // sinh performs a simple elementwise hyperbolic sin function.
            template <>
            void Impl<op::Sinh>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "sinh(I)"})
                               .finalize());
            }

            // sqrt performs a simple elementwise square root function.
            template <>
            void Impl<op::Sqrt>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "sqrt(I)"})
                               .finalize());
            }

            // tan performs a simple elementwise tangent function.
            template <>
            void Impl<op::Tan>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "tan(I)"})
                               .finalize());
            }

            // tanh performs a simple elementwise hyperbolic tangent function.
            template <>
            void Impl<op::Tanh>::operator()()
            {
                check_inputs(1);
                check_outputs(1);
                set_output(start_tile_function()
                               .add(builder::Input{op_input(0), "I"})
                               .add(builder::Output{"O"})
                               .add(builder::Elementwise{"O", "tanh(I)"})
                               .finalize());
            }
namespace
{
    ngraph::runtime::plaidml::Impl<ngraph::op::Acos>::Registration register_acos;
    ngraph::runtime::plaidml::Impl<ngraph::op::Asin>::Registration register_asin;
    ngraph::runtime::plaidml::Impl<ngraph::op::Atan>::Registration register_atan;
    ngraph::runtime::plaidml::Impl<ngraph::op::Cos>::Registration register_cos;
    ngraph::runtime::plaidml::Impl<ngraph::op::Cosh>::Registration register_cosh;
    ngraph::runtime::plaidml::Impl<ngraph::op::Exp>::Registration register_exp;
    ngraph::runtime::plaidml::Impl<ngraph::op::Log>::Registration register_log;
    ngraph::runtime::plaidml::Impl<ngraph::op::Power>::Registration register_power;
    ngraph::runtime::plaidml::Impl<ngraph::op::Sin>::Registration register_sin;
    ngraph::runtime::plaidml::Impl<ngraph::op::Sinh>::Registration register_sinh;
    ngraph::runtime::plaidml::Impl<ngraph::op::Sqrt>::Registration register_sqrt;
    ngraph::runtime::plaidml::Impl<ngraph::op::Tan>::Registration register_tan;
    ngraph::runtime::plaidml::Impl<ngraph::op::Tanh>::Registration register_tanh;
}
            namespace
            {
                Impl<op::Acos>::Registration register_acos;
                Impl<op::Asin>::Registration register_asin;
                Impl<op::Atan>::Registration register_atan;
                Impl<op::Cos>::Registration register_cos;
                Impl<op::Cosh>::Registration register_cosh;
                Impl<op::Exp>::Registration register_exp;
                Impl<op::Log>::Registration register_log;
                Impl<op::Power>::Registration register_power;
                Impl<op::Sin>::Registration register_sin;
                Impl<op::Sinh>::Registration register_sinh;
                Impl<op::Sqrt>::Registration register_sqrt;
                Impl<op::Tan>::Registration register_tan;
                Impl<op::Tanh>::Registration register_tanh;
            }
        }
    }
}
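The Registration objects above do their work at static-initialization time: constructing one presumably inserts the op's implementation into a dispatch table keyed by op type, and the named globals exist only for that side effect. A hedged sketch of the idiom with hypothetical names (the real Impl/Registration machinery is in plaidml_impl.hpp):

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Hypothetical dispatch table; the real one maps nGraph op types to builders.
std::map<std::string, std::function<void()>>& registry()
{
    static std::map<std::string, std::function<void()>> table;
    return table;
}

struct Registration
{
    Registration(const std::string& name, std::function<void()> build)
    {
        registry().emplace(name, std::move(build)); // runs before main()
    }
};

namespace
{
    // Parallels `Impl<op::Tanh>::Registration register_tanh;`: the object is
    // never referenced again; its constructor's side effect is the point.
    Registration register_tanh{"Tanh", [] { std::cout << "building tanh kernel\n"; }};
}

int main()
{
    registry().at("Tanh")(); // look up and invoke the registered builder
    return 0;
}

Routing every lookup through the function-local static in registry() keeps the table usable during static initialization regardless of translation-unit ordering.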
src/ngraph/runtime/plaidml/unit_test.manifest
View file @
61df6725
...
...
@@ -38,6 +38,8 @@ topk_2d_max_one # No plans to implement TopK
topk_2d_min_all # No plans to implement TopK
topk_2d_min_partial # No plans to implement TopK
topk_2d_min_one # No plans to implement TopK
topk_int64 # No plans to implement TopK
topk_5d_max_partial # No plans to implement TopK
# Tests that PlaidML might be able to run at some point.
backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3
...
...
@@ -84,3 +86,5 @@ sum_3d_eliminate_zero_dim # Empty dims apparently should produce shape
dot_0_0 # Empty dims apparently should produce shaped 0s
dot_matrix_2x0_0x2 # Empty dims apparently should produce shaped 0s
dot_2x0_0 # Empty dims apparently should produce shaped 0s
numeric_float_nan
numeric_double_nan