Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
61df6725
Commit
61df6725
authored
Oct 31, 2018
by
Rob Earhart
Committed by
Robert Kimball
Oct 31, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PlaidML] Specialize within namespaces (for Linux) (#1948)
parent
5698fa75
Show whitespace changes
Inline
Side-by-side
Showing
22 changed files
with
936 additions
and
686 deletions
+936
-686
plaidml_ops_arithmetic.cpp
src/ngraph/runtime/plaidml/plaidml_ops_arithmetic.cpp
+91
-82
plaidml_ops_batch_norm.cpp
src/ngraph/runtime/plaidml/plaidml_ops_batch_norm.cpp
+70
-44
plaidml_ops_comparison.cpp
src/ngraph/runtime/plaidml/plaidml_ops_comparison.cpp
+58
-49
plaidml_ops_concat.cpp
src/ngraph/runtime/plaidml/plaidml_ops_concat.cpp
+26
-12
plaidml_ops_convert.cpp
src/ngraph/runtime/plaidml/plaidml_ops_convert.cpp
+18
-8
plaidml_ops_convolution.cpp
src/ngraph/runtime/plaidml/plaidml_ops_convolution.cpp
+44
-46
plaidml_ops_dot.cpp
src/ngraph/runtime/plaidml/plaidml_ops_dot.cpp
+19
-9
plaidml_ops_function.cpp
src/ngraph/runtime/plaidml/plaidml_ops_function.cpp
+18
-8
plaidml_ops_general.cpp
src/ngraph/runtime/plaidml/plaidml_ops_general.cpp
+127
-97
plaidml_ops_index_reduction.cpp
src/ngraph/runtime/plaidml/plaidml_ops_index_reduction.cpp
+50
-43
plaidml_ops_io.cpp
src/ngraph/runtime/plaidml/plaidml_ops_io.cpp
+22
-13
plaidml_ops_local_response_norm.cpp
...graph/runtime/plaidml/plaidml_ops_local_response_norm.cpp
+23
-11
plaidml_ops_logical.cpp
src/ngraph/runtime/plaidml/plaidml_ops_logical.cpp
+28
-19
plaidml_ops_one_hot.cpp
src/ngraph/runtime/plaidml/plaidml_ops_one_hot.cpp
+26
-13
plaidml_ops_pool.cpp
src/ngraph/runtime/plaidml/plaidml_ops_pool.cpp
+36
-27
plaidml_ops_reduce.cpp
src/ngraph/runtime/plaidml/plaidml_ops_reduce.cpp
+96
-82
plaidml_ops_replace_slice.cpp
src/ngraph/runtime/plaidml/plaidml_ops_replace_slice.cpp
+26
-13
plaidml_ops_reverse.cpp
src/ngraph/runtime/plaidml/plaidml_ops_reverse.cpp
+18
-9
plaidml_ops_slice.cpp
src/ngraph/runtime/plaidml/plaidml_ops_slice.cpp
+24
-11
plaidml_ops_softmax.cpp
src/ngraph/runtime/plaidml/plaidml_ops_softmax.cpp
+24
-11
plaidml_ops_transcendental.cpp
src/ngraph/runtime/plaidml/plaidml_ops_transcendental.cpp
+88
-79
unit_test.manifest
src/ngraph/runtime/plaidml/unit_test.manifest
+4
-0
No files found.
src/ngraph/runtime/plaidml/plaidml_ops_arithmetic.cpp
View file @
61df6725
...
@@ -28,10 +28,16 @@
...
@@ -28,10 +28,16 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
// Abs performs a simple elementwise absolute value.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Abs
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// Abs performs a simple elementwise absolute value.
template
<>
void
Impl
<
op
::
Abs
>::
operator
()()
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -39,12 +45,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::operator()()
...
@@ -39,12 +45,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"abs(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"abs(I)"
})
.
finalize
());
.
finalize
());
}
}
// Add performs a simple elementwise addition.
// Add performs a simple elementwise addition.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Add
>::
operator
()()
void
Impl
<
op
::
Add
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -53,12 +59,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Add>::operator()()
...
@@ -53,12 +59,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Add>::operator()()
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A + B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A + B"
})
.
finalize
());
.
finalize
());
}
}
// Ceiling performs a simple elementwise ceiling.
// Ceiling performs a simple elementwise ceiling.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Ceiling
>::
operator
()()
void
Impl
<
op
::
Ceiling
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -66,12 +72,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Ceiling>::operator()()
...
@@ -66,12 +72,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Ceiling>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"ceil(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"ceil(I)"
})
.
finalize
());
.
finalize
());
}
}
// Divide performs a simple elementwise division.
// Divide performs a simple elementwise division.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Divide
>::
operator
()()
void
Impl
<
op
::
Divide
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -80,12 +86,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Divide>::operator()()
...
@@ -80,12 +86,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Divide>::operator()()
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A / B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A / B"
})
.
finalize
());
.
finalize
());
}
}
// Floor performs a simple elementwise floor.
// Floor performs a simple elementwise floor.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Floor
>::
operator
()()
void
Impl
<
op
::
Floor
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -93,12 +99,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Floor>::operator()()
...
@@ -93,12 +99,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Floor>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"floor(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"floor(I)"
})
.
finalize
());
.
finalize
());
}
}
// Multiply performs a simple elementwise multiplication.
// Multiply performs a simple elementwise multiplication.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Multiply
>::
operator
()()
void
Impl
<
op
::
Multiply
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -107,12 +113,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Multiply>::operator()()
...
@@ -107,12 +113,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Multiply>::operator()()
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A * B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A * B"
})
.
finalize
());
.
finalize
());
}
}
// Negative performs a simple elementwise negation.
// Negative performs a simple elementwise negation.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Negative
>::
operator
()()
void
Impl
<
op
::
Negative
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -120,12 +126,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Negative>::operator()()
...
@@ -120,12 +126,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Negative>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"-I"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"-I"
})
.
finalize
());
.
finalize
());
}
}
// Relu implements a simple elementwise rectified linear unit.
// Relu implements a simple elementwise rectified linear unit.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Relu
>::
operator
()()
void
Impl
<
op
::
Relu
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -133,12 +139,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Relu>::operator()()
...
@@ -133,12 +139,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Relu>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"relu(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"relu(I)"
})
.
finalize
());
.
finalize
());
}
}
// ReluBackprop computes the derivative of Relu.
// ReluBackprop computes the derivative of Relu.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ReluBackprop
>::
operator
()()
void
Impl
<
op
::
ReluBackprop
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -147,12 +153,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReluBackprop>::operator()()
...
@@ -147,12 +153,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReluBackprop>::operator()()
.
add
(
builder
::
Output
{
"DI"
})
.
add
(
builder
::
Output
{
"DI"
})
.
add
(
builder
::
Elementwise
{
"DI"
,
"I > 0 ? DO : 0"
})
.
add
(
builder
::
Elementwise
{
"DI"
,
"I > 0 ? DO : 0"
})
.
finalize
());
.
finalize
());
}
}
// Sigmoid computes a standard ML sigmoid: 1/(1+exp(-X))
// Sigmoid computes a standard ML sigmoid: 1/(1+exp(-X))
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sigmoid
>::
operator
()()
void
Impl
<
op
::
Sigmoid
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -160,13 +166,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sigmoid>::operator()()
...
@@ -160,13 +166,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sigmoid>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"1/(1+exp(-I))"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"1/(1+exp(-I))"
})
.
finalize
());
.
finalize
());
}
}
// SigmoidBackprop computes the derivative of a standard ML
// SigmoidBackprop computes the derivative of a standard ML
// sigmoid: dOutput * sigmoid(X) * (1-sigmoid(X))
// sigmoid: dOutput * sigmoid(X) * (1-sigmoid(X))
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
SigmoidBackprop
>::
operator
()()
void
Impl
<
op
::
SigmoidBackprop
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -176,26 +182,27 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::SigmoidBackprop>::operator()()
...
@@ -176,26 +182,27 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::SigmoidBackprop>::operator()()
.
add
(
builder
::
Elementwise
{
"O"
,
"1/(1+exp(-I))"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"1/(1+exp(-I))"
})
.
add
(
builder
::
Elementwise
{
"DI"
,
"DO * O * (1-O)"
})
.
add
(
builder
::
Elementwise
{
"DI"
,
"DO * O * (1-O)"
})
.
finalize
());
.
finalize
());
}
}
// Sign returns the sign of an element.
// Sign returns the sign of an element.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sign
>::
operator
()()
void
Impl
<
op
::
Sign
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(
0
),
"I"
})
.
add
(
builder
::
Input
{
op_input
(
0
),
"I"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"S"
,
"(I < 0) ? -1 : ((I > 0) ? 1 : 0)"
})
.
add
(
builder
::
Elementwise
{
"S"
,
"(I < 0) ? -1 : ((I > 0) ? 1 : 0)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
tile_converter
(
"S"
,
op
().
get_element_type
())})
.
add
(
builder
::
Elementwise
{
"O"
,
tile_converter
(
"S"
,
op
().
get_element_type
())})
.
finalize
());
.
finalize
());
}
}
// Subtract performs a simple elementwise subtraction.
// Subtract performs a simple elementwise subtraction.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Subtract
>::
operator
()()
void
Impl
<
op
::
Subtract
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -204,22 +211,24 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Subtract>::operator()()
...
@@ -204,22 +211,24 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Subtract>::operator()()
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A - B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A - B"
})
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Abs
>::
Registration
register_abs
;
Impl
<
op
::
Abs
>::
Registration
register_abs
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Add
>::
Registration
register_add
;
Impl
<
op
::
Add
>::
Registration
register_add
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Ceiling
>::
Registration
register_ceiling
;
Impl
<
op
::
Ceiling
>::
Registration
register_ceiling
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Divide
>::
Registration
register_divide
;
Impl
<
op
::
Divide
>::
Registration
register_divide
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Floor
>::
Registration
register_floor
;
Impl
<
op
::
Floor
>::
Registration
register_floor
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Multiply
>::
Registration
register_multiply
;
Impl
<
op
::
Multiply
>::
Registration
register_multiply
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Negative
>::
Registration
register_negative
;
Impl
<
op
::
Negative
>::
Registration
register_negative
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Relu
>::
Registration
register_relu
;
Impl
<
op
::
Relu
>::
Registration
register_relu
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ReluBackprop
>::
Registration
register_relu_backprop
;
Impl
<
op
::
ReluBackprop
>::
Registration
register_relu_backprop
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sigmoid
>::
Registration
register_sigmoid
;
Impl
<
op
::
Sigmoid
>::
Registration
register_sigmoid
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
SigmoidBackprop
>::
Registration
Impl
<
op
::
SigmoidBackprop
>::
Registration
register_sigmoid_backprop
;
register_sigmoid_backprop
;
Impl
<
op
::
Sign
>::
Registration
register_sign
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sign
>::
Registration
register_sign
;
Impl
<
op
::
Subtract
>::
Registration
register_subtract
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Subtract
>::
Registration
register_subtract
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_batch_norm.cpp
View file @
61df6725
...
@@ -18,11 +18,17 @@
...
@@ -18,11 +18,17 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// BatchNormInference implements batch normalization for inference, in
namespace
ngraph
// which the mean and variance to use are supplied.
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
BatchNormInference
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// BatchNormInference implements batch normalization for inference, in
// which the mean and variance to use are supplied.
template
<>
void
Impl
<
op
::
BatchNormInference
>::
operator
()()
{
auto
&
input_shape
=
op
().
get_input_shape
(
2
);
auto
&
input_shape
=
op
().
get_input_shape
(
2
);
check_inputs
(
5
);
check_inputs
(
5
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -45,12 +51,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()(
...
@@ -45,12 +51,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()(
if
(
input_shape
.
size
()
<=
2
)
if
(
input_shape
.
size
()
<=
2
)
{
{
f
.
add
(
builder
::
Elementwise
{
"GammaP"
,
"Gamma"
}).
add
(
builder
::
Elementwise
{
"BetaP"
,
"Beta"
});
f
.
add
(
builder
::
Elementwise
{
"GammaP"
,
"Gamma"
})
.
add
(
builder
::
Elementwise
{
"BetaP"
,
"Beta"
});
}
}
else
else
{
{
f
.
add
(
builder
::
Elementwise
{
"GammaP"
,
std
::
string
{
"reshape(Gamma, C"
}
+
ones
+
")"
})
f
.
add
(
builder
::
Elementwise
{
"GammaP"
,
.
add
(
builder
::
Elementwise
{
"BetaP"
,
std
::
string
{
"reshape(Beta, C"
}
+
ones
+
")"
});
std
::
string
{
"reshape(Gamma, C"
}
+
ones
+
")"
})
.
add
(
builder
::
Elementwise
{
"BetaP"
,
std
::
string
{
"reshape(Beta, C"
}
+
ones
+
")"
});
}
}
if
(
input_shape
.
size
()
<=
2
)
if
(
input_shape
.
size
()
<=
2
)
...
@@ -59,7 +68,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()(
...
@@ -59,7 +68,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()(
}
}
else
else
{
{
f
.
add
(
builder
::
Elementwise
{
"MeanP"
,
std
::
string
{
"reshape(Mean, C"
}
+
ones
+
")"
});
f
.
add
(
builder
::
Elementwise
{
"MeanP"
,
std
::
string
{
"reshape(Mean, C"
}
+
ones
+
")"
});
}
}
if
(
input_shape
.
size
()
<=
2
)
if
(
input_shape
.
size
()
<=
2
)
...
@@ -68,24 +78,26 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()(
...
@@ -68,24 +78,26 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()(
}
}
else
else
{
{
f
.
add
(
builder
::
Elementwise
{
"VarianceP"
,
std
::
string
{
"reshape(Variance, C"
}
+
ones
+
")"
});
f
.
add
(
builder
::
Elementwise
{
"VarianceP"
,
std
::
string
{
"reshape(Variance, C"
}
+
ones
+
")"
});
}
}
f
.
add
(
builder
::
Elementwise
{
"Normalized"
,
f
.
add
(
builder
::
Elementwise
{
"Normalized"
,
"(((Input-MeanP) / sqrt(VarianceP + "
+
"(((Input-MeanP) / sqrt(VarianceP + "
+
std
::
to_string
(
op
().
get_eps_value
())
+
")) * GammaP) + BetaP"
});
std
::
to_string
(
op
().
get_eps_value
())
+
")) * GammaP) + BetaP"
});
auto
app
=
f
.
finalize
();
auto
app
=
f
.
finalize
();
set_output
(
app
);
set_output
(
app
);
}
}
// BatchNormTraining implements batch normalization for training, in
// BatchNormTraining implements batch normalization for training, in
// which the mean and variance are to be computed from the supplied
// which the mean and variance are to be computed from the supplied
// input.
// input.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
BatchNormTraining
>::
operator
()()
void
Impl
<
op
::
BatchNormTraining
>::
operator
()()
{
{
auto
&
input_shape
=
op
().
get_input_shape
(
2
);
auto
&
input_shape
=
op
().
get_input_shape
(
2
);
check_inputs
(
3
);
check_inputs
(
3
);
check_outputs
(
3
);
check_outputs
(
3
);
...
@@ -108,12 +120,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
...
@@ -108,12 +120,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
if
(
input_shape
.
size
()
<=
2
)
if
(
input_shape
.
size
()
<=
2
)
{
{
f
.
add
(
builder
::
Elementwise
{
"GammaP"
,
"Gamma"
}).
add
(
builder
::
Elementwise
{
"BetaP"
,
"Beta"
});
f
.
add
(
builder
::
Elementwise
{
"GammaP"
,
"Gamma"
})
.
add
(
builder
::
Elementwise
{
"BetaP"
,
"Beta"
});
}
}
else
else
{
{
f
.
add
(
builder
::
Elementwise
{
"GammaP"
,
std
::
string
{
"reshape(Gamma, C"
}
+
ones
+
")"
})
f
.
add
(
builder
::
Elementwise
{
"GammaP"
,
.
add
(
builder
::
Elementwise
{
"BetaP"
,
std
::
string
{
"reshape(Beta, C"
}
+
ones
+
")"
});
std
::
string
{
"reshape(Gamma, C"
}
+
ones
+
")"
})
.
add
(
builder
::
Elementwise
{
"BetaP"
,
std
::
string
{
"reshape(Beta, C"
}
+
ones
+
")"
});
}
}
if
(
input_shape
.
size
()
<=
2
)
if
(
input_shape
.
size
()
<=
2
)
...
@@ -131,7 +146,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
...
@@ -131,7 +146,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
}
}
f
.
add
(
builder
::
UnaryContraction
{
"+"
}
f
.
add
(
builder
::
UnaryContraction
{
"+"
}
.
set
(
builder
::
ContractionOutput
{
"SumInput"
}.
add_indices
({
"c"
}).
add_dims
({
"C"
}))
.
set
(
builder
::
ContractionOutput
{
"SumInput"
}.
add_indices
({
"c"
}).
add_dims
(
{
"C"
}))
.
set
(
builder
::
ContractionInput
{
"Input"
}
.
set
(
builder
::
ContractionInput
{
"Input"
}
.
add_indices
({
"b"
,
"c"
})
.
add_indices
({
"b"
,
"c"
})
.
add_indices
(
"di"
,
3
,
input_shape
.
size
()
+
1
)));
.
add_indices
(
"di"
,
3
,
input_shape
.
size
()
+
1
)));
...
@@ -143,13 +159,16 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
...
@@ -143,13 +159,16 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
}
}
else
else
{
{
f
.
add
(
builder
::
Elementwise
{
"MeanP"
,
std
::
string
{
"reshape(Mean, C"
}
+
ones
+
")"
});
f
.
add
(
builder
::
Elementwise
{
"MeanP"
,
std
::
string
{
"reshape(Mean, C"
}
+
ones
+
")"
});
}
}
f
.
add
(
builder
::
Elementwise
{
"DiffV"
,
"(Input - MeanP)"
})
f
.
add
(
builder
::
Elementwise
{
"DiffV"
,
"(Input - MeanP)"
})
.
add
(
builder
::
Elementwise
{
"SqDiffV"
,
"DiffV*DiffV"
})
.
add
(
builder
::
Elementwise
{
"SqDiffV"
,
"DiffV*DiffV"
})
.
add
(
builder
::
UnaryContraction
{
"+"
}
.
add
(
builder
::
UnaryContraction
{
"+"
}
.
set
(
builder
::
ContractionOutput
{
"SumSqDiffV"
}.
add_indices
({
"c"
}).
add_dims
({
"C"
}))
.
set
(
builder
::
ContractionOutput
{
"SumSqDiffV"
}
.
add_indices
({
"c"
})
.
add_dims
({
"C"
}))
.
set
(
builder
::
ContractionInput
{
"SqDiffV"
}
.
set
(
builder
::
ContractionInput
{
"SqDiffV"
}
.
add_indices
({
"b"
,
"c"
})
.
add_indices
({
"b"
,
"c"
})
.
add_indices
(
"di"
,
3
,
input_shape
.
size
()
+
1
)))
.
add_indices
(
"di"
,
3
,
input_shape
.
size
()
+
1
)))
...
@@ -161,23 +180,25 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
...
@@ -161,23 +180,25 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
}
}
else
else
{
{
f
.
add
(
builder
::
Elementwise
{
"VarianceP"
,
std
::
string
{
"reshape(Variance, C"
}
+
ones
+
")"
});
f
.
add
(
builder
::
Elementwise
{
"VarianceP"
,
std
::
string
{
"reshape(Variance, C"
}
+
ones
+
")"
});
}
}
f
.
add
(
builder
::
Elementwise
{
"Normalized"
,
f
.
add
(
builder
::
Elementwise
{
"Normalized"
,
"(((Input-MeanP) / sqrt(VarianceP + "
+
"(((Input-MeanP) / sqrt(VarianceP + "
+
std
::
to_string
(
op
().
get_eps_value
())
+
")) * GammaP) + BetaP"
});
std
::
to_string
(
op
().
get_eps_value
())
+
")) * GammaP) + BetaP"
});
auto
app
=
f
.
finalize
();
auto
app
=
f
.
finalize
();
set_output
(
0
,
app
.
get_output
(
0
));
set_output
(
0
,
app
.
get_output
(
0
));
set_output
(
1
,
app
.
get_output
(
1
));
set_output
(
1
,
app
.
get_output
(
1
));
set_output
(
2
,
app
.
get_output
(
2
));
set_output
(
2
,
app
.
get_output
(
2
));
}
}
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
BatchNormTrainingBackprop
>::
operator
()()
void
Impl
<
op
::
BatchNormTrainingBackprop
>::
operator
()()
{
{
// WARNING: I'm unconvinced that we have sufficient test converage for BatchNorm
// WARNING: I'm unconvinced that we have sufficient test converage for BatchNorm
// backprop and in particular I'm concerned that Gamma/Beta and Mean/Var could be
// backprop and in particular I'm concerned that Gamma/Beta and Mean/Var could be
// swapped without the tests catching it.
// swapped without the tests catching it.
...
@@ -232,10 +253,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
...
@@ -232,10 +253,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
.
set
(
builder
::
ContractionInput
{
"Input"
}
.
set
(
builder
::
ContractionInput
{
"Input"
}
.
add_indices
({
"n"
,
"c"
})
.
add_indices
({
"n"
,
"c"
})
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
)));
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
)));
f
.
add
(
builder
::
Elementwise
{
"BatchMean"
,
"BatchMeanNumerator / "
+
reduction_dims
.
str
()});
f
.
add
(
builder
::
Elementwise
{
"BatchMean"
,
"BatchMeanNumerator / "
+
reduction_dims
.
str
()});
f
.
add
(
builder
::
Elementwise
{
"NegBatchMean"
,
"-BatchMean"
});
f
.
add
(
builder
::
Elementwise
{
"NegBatchMean"
,
"-BatchMean"
});
f
.
add
(
f
.
add
(
builder
::
BinaryContraction
{
"="
,
"+"
}
builder
::
BinaryContraction
{
"="
,
"+"
}
.
set
(
builder
::
ContractionOutput
{
"Deviation"
}
.
set
(
builder
::
ContractionOutput
{
"Deviation"
}
.
add_indices
({
"n"
,
"c"
})
.
add_indices
({
"n"
,
"c"
})
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
)
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
)
...
@@ -244,7 +265,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
...
@@ -244,7 +265,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
.
set_lhs
(
builder
::
ContractionInput
{
"Input"
}
.
set_lhs
(
builder
::
ContractionInput
{
"Input"
}
.
add_indices
({
"n"
,
"c"
})
.
add_indices
({
"n"
,
"c"
})
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
))
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
))
.
set_rhs
(
builder
::
ContractionInput
{
"NegBatchMean"
}.
add_indices
({
"0"
,
"c"
,
"0"
,
"0"
})));
.
set_rhs
(
builder
::
ContractionInput
{
"NegBatchMean"
}.
add_indices
(
{
"0"
,
"c"
,
"0"
,
"0"
})));
f
.
add
(
builder
::
BinaryContraction
{
"+"
,
"*"
}
f
.
add
(
builder
::
BinaryContraction
{
"+"
,
"*"
}
.
set
(
builder
::
ContractionOutput
{
"BatchVarNumerator"
}
.
set
(
builder
::
ContractionOutput
{
"BatchVarNumerator"
}
.
add_indices
({
"0"
,
"c"
,
"0"
,
"0"
})
.
add_indices
({
"0"
,
"c"
,
"0"
,
"0"
})
...
@@ -255,7 +277,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
...
@@ -255,7 +277,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
.
set_rhs
(
builder
::
ContractionInput
{
"Deviation"
}
.
set_rhs
(
builder
::
ContractionInput
{
"Deviation"
}
.
add_indices
({
"n"
,
"c"
})
.
add_indices
({
"n"
,
"c"
})
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
)));
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
)));
f
.
add
(
builder
::
Elementwise
{
"BatchVar"
,
"BatchVarNumerator / "
+
reduction_dims
.
str
()});
f
.
add
(
builder
::
Elementwise
{
"BatchVar"
,
"BatchVarNumerator / "
+
reduction_dims
.
str
()});
f
.
add
(
builder
::
Elementwise
{
"BatchStdDev"
,
"sqrt(BatchVar + "
+
epsilon
+
")"
});
f
.
add
(
builder
::
Elementwise
{
"BatchStdDev"
,
"sqrt(BatchVar + "
+
epsilon
+
")"
});
f
.
add
(
builder
::
Elementwise
{
"NormedInput"
,
"(Input - BatchMean) / BatchStdDev"
});
f
.
add
(
builder
::
Elementwise
{
"NormedInput"
,
"(Input - BatchMean) / BatchStdDev"
});
...
@@ -266,12 +289,14 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
...
@@ -266,12 +289,14 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
f
.
add
(
builder
::
Elementwise
{
"DNormedInput"
,
"DOutput * BroadcastGamma"
});
f
.
add
(
builder
::
Elementwise
{
"DNormedInput"
,
"DOutput * BroadcastGamma"
});
f
.
add
(
builder
::
UnaryContraction
{
"+"
}
f
.
add
(
builder
::
UnaryContraction
{
"+"
}
.
set
(
builder
::
ContractionOutput
{
"SumDOutput"
}.
add_indices
({
"c"
}).
add_dims
({
"C"
}))
.
set
(
builder
::
ContractionOutput
{
"SumDOutput"
}.
add_indices
({
"c"
}).
add_dims
(
{
"C"
}))
.
set
(
builder
::
ContractionInput
{
"DOutput"
}
.
set
(
builder
::
ContractionInput
{
"DOutput"
}
.
add_indices
({
"n"
,
"c"
})
.
add_indices
({
"n"
,
"c"
})
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
)));
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
)));
f
.
add
(
builder
::
BinaryContraction
{
"+"
,
"*"
}
f
.
add
(
builder
::
BinaryContraction
{
"+"
,
"*"
}
.
set
(
builder
::
ContractionOutput
{
"DGamma"
}.
add_indices
({
"c"
}).
add_dims
({
"C"
}))
.
set
(
builder
::
ContractionOutput
{
"DGamma"
}.
add_indices
({
"c"
}).
add_dims
(
{
"C"
}))
.
set_lhs
(
builder
::
ContractionInput
{
"DOutput"
}
.
set_lhs
(
builder
::
ContractionInput
{
"DOutput"
}
.
add_indices
({
"n"
,
"c"
})
.
add_indices
({
"n"
,
"c"
})
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
))
.
add_indices
(
"x"
,
3
,
input_shape
.
size
()
+
1
))
...
@@ -295,14 +320,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
...
@@ -295,14 +320,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
set_output
(
0
,
app
.
get_output
(
0
));
set_output
(
0
,
app
.
get_output
(
0
));
set_output
(
1
,
app
.
get_output
(
1
));
set_output
(
1
,
app
.
get_output
(
1
));
set_output
(
2
,
app
.
get_output
(
2
));
set_output
(
2
,
app
.
get_output
(
2
));
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
BatchNormInference
>::
Registration
Impl
<
op
::
BatchNormInference
>::
Registration
register_batch_norm_inference
;
register_batch_norm_inference
;
Impl
<
op
::
BatchNormTraining
>::
Registration
register_batch_norm_training
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
BatchNormTraining
>::
Registration
Impl
<
op
::
BatchNormTrainingBackprop
>::
Registration
register_batch_norm_training
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
BatchNormTrainingBackprop
>::
Registration
register_batch_norm_training_backprop
;
register_batch_norm_training_backprop
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_comparison.cpp
View file @
61df6725
...
@@ -24,10 +24,16 @@
...
@@ -24,10 +24,16 @@
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Equal performs a simple elementwise equality.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Equal
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// Equal performs a simple elementwise equality.
template
<>
void
Impl
<
op
::
Equal
>::
operator
()()
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -37,12 +43,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Equal>::operator()()
...
@@ -37,12 +43,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Equal>::operator()()
.
add
(
builder
::
Elementwise
{
"C"
,
"A == B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A == B"
})
.
finalize
(),
.
finalize
(),
TensorContents
::
LOGICAL
);
TensorContents
::
LOGICAL
);
}
}
// Greater performs a simple elementwise greater-than comparison.
// Greater performs a simple elementwise greater-than comparison.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Greater
>::
operator
()()
void
Impl
<
op
::
Greater
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -52,12 +58,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Greater>::operator()()
...
@@ -52,12 +58,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Greater>::operator()()
.
add
(
builder
::
Elementwise
{
"C"
,
"A > B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A > B"
})
.
finalize
(),
.
finalize
(),
TensorContents
::
LOGICAL
);
TensorContents
::
LOGICAL
);
}
}
// GreaterEq performs a simple elementwise greater-than-or-equal-to comparison.
// GreaterEq performs a simple elementwise greater-than-or-equal-to comparison.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
GreaterEq
>::
operator
()()
void
Impl
<
op
::
GreaterEq
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -67,12 +73,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::GreaterEq>::operator()()
...
@@ -67,12 +73,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::GreaterEq>::operator()()
.
add
(
builder
::
Elementwise
{
"C"
,
"A >= B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A >= B"
})
.
finalize
(),
.
finalize
(),
TensorContents
::
LOGICAL
);
TensorContents
::
LOGICAL
);
}
}
// Less performs a simple elementwise less-than comparison.
// Less performs a simple elementwise less-than comparison.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Less
>::
operator
()()
void
Impl
<
op
::
Less
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -82,12 +88,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Less>::operator()()
...
@@ -82,12 +88,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Less>::operator()()
.
add
(
builder
::
Elementwise
{
"C"
,
"A < B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A < B"
})
.
finalize
(),
.
finalize
(),
TensorContents
::
LOGICAL
);
TensorContents
::
LOGICAL
);
}
}
// LessEq performs a simple elementwise less-than-or-equal-to comparison.
// LessEq performs a simple elementwise less-than-or-equal-to comparison.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
LessEq
>::
operator
()()
void
Impl
<
op
::
LessEq
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -97,12 +103,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::LessEq>::operator()()
...
@@ -97,12 +103,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::LessEq>::operator()()
.
add
(
builder
::
Elementwise
{
"C"
,
"A <= B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A <= B"
})
.
finalize
(),
.
finalize
(),
TensorContents
::
LOGICAL
);
TensorContents
::
LOGICAL
);
}
}
// Maximum performs a simple elementwise maximum.
// Maximum performs a simple elementwise maximum.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Maximum
>::
operator
()()
void
Impl
<
op
::
Maximum
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -111,12 +117,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Maximum>::operator()()
...
@@ -111,12 +117,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Maximum>::operator()()
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"max(A, B)"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"max(A, B)"
})
.
finalize
());
.
finalize
());
}
}
// Minimum performs a simple elementwise minimum.
// Minimum performs a simple elementwise minimum.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Minimum
>::
operator
()()
void
Impl
<
op
::
Minimum
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -125,12 +131,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Minimum>::operator()()
...
@@ -125,12 +131,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Minimum>::operator()()
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Output
{
"C"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"min(A, B)"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"min(A, B)"
})
.
finalize
());
.
finalize
());
}
}
// NotEqual performs a simple elementwise not-equality.
// NotEqual performs a simple elementwise not-equality.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
NotEqual
>::
operator
()()
void
Impl
<
op
::
NotEqual
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -140,16 +146,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::NotEqual>::operator()()
...
@@ -140,16 +146,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::NotEqual>::operator()()
.
add
(
builder
::
Elementwise
{
"C"
,
"A != B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A != B"
})
.
finalize
(),
.
finalize
(),
TensorContents
::
LOGICAL
);
TensorContents
::
LOGICAL
);
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Equal
>::
Registration
register_equal
;
Impl
<
op
::
Equal
>::
Registration
register_equal
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Greater
>::
Registration
register_greater
;
Impl
<
op
::
Greater
>::
Registration
register_greater
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
GreaterEq
>::
Registration
register_greater_eq
;
Impl
<
op
::
GreaterEq
>::
Registration
register_greater_eq
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Less
>::
Registration
register_less
;
Impl
<
op
::
Less
>::
Registration
register_less
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
LessEq
>::
Registration
register_less_eq
;
Impl
<
op
::
LessEq
>::
Registration
register_less_eq
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Maximum
>::
Registration
register_maximum
;
Impl
<
op
::
Maximum
>::
Registration
register_maximum
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Minimum
>::
Registration
register_minimum
;
Impl
<
op
::
Minimum
>::
Registration
register_minimum
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
NotEqual
>::
Registration
register_not_equal
;
Impl
<
op
::
NotEqual
>::
Registration
register_not_equal
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_concat.cpp
View file @
61df6725
...
@@ -17,10 +17,16 @@
...
@@ -17,10 +17,16 @@
#include "ngraph/op/concat.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Concat views a tensor as a new type.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Concat
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// Concat views a tensor as a new type.
template
<>
void
Impl
<
op
::
Concat
>::
operator
()()
{
check_outputs
(
1
);
check_outputs
(
1
);
auto
f
=
start_tile_function
();
auto
f
=
start_tile_function
();
...
@@ -52,10 +58,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
...
@@ -52,10 +58,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
continue
;
continue
;
}
}
std
::
string
sidx
{
std
::
to_string
(
iidx
)};
std
::
string
sidx
{
std
::
to_string
(
iidx
)};
f
.
add
(
builder
::
Input
{
op_input
(
iidx
),
"I"
+
sidx
}.
add_dims
(
"I"
+
sidx
+
"_D"
,
0
,
dim_count
));
f
.
add
(
builder
::
Input
{
op_input
(
iidx
),
"I"
+
sidx
}.
add_dims
(
"I"
+
sidx
+
"_D"
,
0
,
dim_count
));
f
.
add
(
builder
::
UnaryContraction
{
"="
}
f
.
add
(
builder
::
UnaryContraction
{
"="
}
.
set
(
builder
::
ContractionOutput
{
"E"
+
sidx
}
.
set
(
builder
::
ContractionOutput
{
"E"
+
sidx
}
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_count
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_count
;
++
idx
)
{
{
std
::
ostringstream
s
;
std
::
ostringstream
s
;
...
@@ -70,19 +78,22 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
...
@@ -70,19 +78,22 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
}
}
}
}
})
})
.
add_indices
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_indices
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_count
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_count
;
++
idx
)
{
{
std
::
ostringstream
s
;
std
::
ostringstream
s
;
s
<<
"d"
<<
idx
;
s
<<
"d"
<<
idx
;
if
(
saw_non_zero_tensor
&&
idx
==
op
().
get_concatenation_axis
())
if
(
saw_non_zero_tensor
&&
idx
==
op
().
get_concatenation_axis
())
{
{
s
<<
" + "
<<
offset
.
str
();
s
<<
" + "
<<
offset
.
str
();
}
}
out
=
s
.
str
();
out
=
s
.
str
();
}
}
}))
}))
.
set
(
builder
::
ContractionInput
{
"I"
+
sidx
}.
add_indices
(
"d"
,
0
,
dim_count
)));
.
set
(
builder
::
ContractionInput
{
"I"
+
sidx
}.
add_indices
(
"d"
,
0
,
dim_count
)));
if
(
saw_non_zero_tensor
)
if
(
saw_non_zero_tensor
)
{
{
oexpr
<<
" + "
;
oexpr
<<
" + "
;
...
@@ -95,9 +106,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
...
@@ -95,9 +106,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
f
.
add
(
builder
::
Elementwise
{
"O"
,
oexpr
.
str
()});
f
.
add
(
builder
::
Elementwise
{
"O"
,
oexpr
.
str
()});
set_output
(
f
.
finalize
());
set_output
(
f
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Concat
>::
Registration
register_concat
;
Impl
<
op
::
Concat
>::
Registration
register_concat
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_convert.cpp
View file @
61df6725
...
@@ -18,21 +18,31 @@
...
@@ -18,21 +18,31 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
// Convert views a tensor as a new type.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Convert
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// Convert views a tensor as a new type.
template
<>
void
Impl
<
op
::
Convert
>::
operator
()()
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(),
"I"
})
.
add
(
builder
::
Input
{
op_input
(),
"I"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
.
add
(
builder
::
Elementwise
{
"O"
,
tile_converter
(
"I"
,
to_plaidml
(
op
().
get_convert_element_type
()))})
"O"
,
tile_converter
(
"I"
,
to_plaidml
(
op
().
get_convert_element_type
()))})
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Convert
>::
Registration
register_convert
;
Impl
<
op
::
Convert
>::
Registration
register_convert
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_convolution.cpp
View file @
61df6725
...
@@ -50,32 +50,29 @@ namespace ngraph
...
@@ -50,32 +50,29 @@ namespace ngraph
std
::
size_t
output_channel_axis_result
,
std
::
size_t
output_channel_axis_result
,
bool
rotate_filter
);
bool
rotate_filter
);
};
};
}
}
}
template
<>
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
Convolution
>
struct
ParentImpl
<
op
::
Convolution
>
{
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ConvolutionImpl
<
ngraph
::
op
::
Convolution
>
;
using
Type
=
ConvolutionImpl
<
op
::
Convolution
>
;
};
};
template
<>
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
ConvolutionBackpropFilters
>
struct
ParentImpl
<
op
::
ConvolutionBackpropFilters
>
{
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ConvolutionImpl
<
ngraph
::
op
::
ConvolutionBackpropFilters
>
;
using
Type
=
ConvolutionImpl
<
op
::
ConvolutionBackpropFilters
>
;
};
};
template
<>
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
ConvolutionBackpropData
>
struct
ParentImpl
<
op
::
ConvolutionBackpropData
>
{
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ConvolutionImpl
<
ngraph
::
op
::
ConvolutionBackpropData
>
;
using
Type
=
ConvolutionImpl
<
op
::
ConvolutionBackpropData
>
;
};
};
// Convolution implements a standard ML convolultion, with optional striding, padding, and dilation.
// Convolution implements a standard ML convolultion, with optional striding, padding, and dilation.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Convolution
>::
operator
()()
void
Impl
<
op
::
Convolution
>::
operator
()()
{
{
this
->
check_inputs
(
2
);
this
->
check_inputs
(
2
);
this
->
check_outputs
(
1
);
this
->
check_outputs
(
1
);
...
@@ -122,13 +119,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Convolution>::operator()()
...
@@ -122,13 +119,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Convolution>::operator()()
.
set_lhs
(
cpf
.
I_in_body
())
.
set_lhs
(
cpf
.
I_in_body
())
.
set_rhs
(
cpf
.
F_in_body
()))
.
set_rhs
(
cpf
.
F_in_body
()))
.
finalize
());
.
finalize
());
}
}
// ConvolutionBackpropFilters implements the derivative of a convolution with respect to its filter
// ConvolutionBackpropFilters implements the derivative of a convolution with respect to its filter
// input.
// input.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ConvolutionBackpropFilters
>::
operator
()()
void
Impl
<
op
::
ConvolutionBackpropFilters
>::
operator
()()
{
{
this
->
check_inputs
(
2
);
this
->
check_inputs
(
2
);
this
->
check_outputs
(
1
);
this
->
check_outputs
(
1
);
...
@@ -177,13 +174,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropFilters>::ope
...
@@ -177,13 +174,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropFilters>::ope
.
set_lhs
(
cpf
.
O_in_body
())
.
set_lhs
(
cpf
.
O_in_body
())
.
set_rhs
(
cpf
.
I_in_body
()))
.
set_rhs
(
cpf
.
I_in_body
()))
.
finalize
());
.
finalize
());
}
}
// ConvolutionBackpropData implements the derivative of a convolution with respect to its data
// ConvolutionBackpropData implements the derivative of a convolution with respect to its data
// input.
// input.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ConvolutionBackpropData
>::
operator
()()
void
Impl
<
op
::
ConvolutionBackpropData
>::
operator
()()
{
{
this
->
check_inputs
(
2
);
this
->
check_inputs
(
2
);
this
->
check_outputs
(
1
);
this
->
check_outputs
(
1
);
...
@@ -232,11 +229,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropData>::operat
...
@@ -232,11 +229,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropData>::operat
.
set_lhs
(
cpf
.
O_in_body
())
.
set_lhs
(
cpf
.
O_in_body
())
.
set_rhs
(
cpf
.
F_in_body
()))
.
set_rhs
(
cpf
.
F_in_body
()))
.
finalize
());
.
finalize
());
}
}
template
<
typename
O
>
template
<
typename
O
>
inline
void
ngraph
::
runtime
::
plaidml
::
ConvolutionImpl
<
O
>::
LogConvolution
(
inline
void
ConvolutionImpl
<
O
>::
LogConvolution
(
vertexai
::
plaidml
::
variable
image
,
vertexai
::
plaidml
::
variable
image
,
vertexai
::
plaidml
::
variable
filter
,
vertexai
::
plaidml
::
variable
filter
,
std
::
size_t
image_dims
,
std
::
size_t
image_dims
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_movement_strides
,
...
@@ -251,7 +247,7 @@ inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
...
@@ -251,7 +247,7 @@ inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
std
::
size_t
batch_axis_result
,
std
::
size_t
batch_axis_result
,
std
::
size_t
output_channel_axis_result
,
std
::
size_t
output_channel_axis_result
,
bool
rotate_filter
)
bool
rotate_filter
)
{
{
this
->
check_inputs
(
2
);
this
->
check_inputs
(
2
);
this
->
check_outputs
(
1
);
this
->
check_outputs
(
1
);
...
@@ -271,13 +267,15 @@ inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
...
@@ -271,13 +267,15 @@ inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
NGRAPH_DEBUG
<<
"batch_axis_result: "
<<
batch_axis_result
;
NGRAPH_DEBUG
<<
"batch_axis_result: "
<<
batch_axis_result
;
NGRAPH_DEBUG
<<
"output_channel_axis_result: "
<<
output_channel_axis_result
;
NGRAPH_DEBUG
<<
"output_channel_axis_result: "
<<
output_channel_axis_result
;
NGRAPH_DEBUG
<<
"rotate_filter: "
<<
rotate_filter
;
NGRAPH_DEBUG
<<
"rotate_filter: "
<<
rotate_filter
;
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Convolution
>::
Registration
register_convolution
;
Impl
<
op
::
Convolution
>::
Registration
register_convolution
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ConvolutionBackpropFilters
>::
Registration
Impl
<
op
::
ConvolutionBackpropFilters
>::
Registration
register_convolution_backprop_filters
;
register_convolution_backprop_filters
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ConvolutionBackpropData
>::
Registration
Impl
<
op
::
ConvolutionBackpropData
>::
Registration
register_convolution_backprop_data
;
register_convolution_backprop_data
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_dot.cpp
View file @
61df6725
...
@@ -20,11 +20,17 @@
...
@@ -20,11 +20,17 @@
#include "ngraph/op/dot.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Dot is a generalized dot product operation -- scalar-tensor,
namespace
ngraph
// matrix-vector, and matrix multiplication.
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Dot
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// Dot is a generalized dot product operation -- scalar-tensor,
// matrix-vector, and matrix multiplication.
template
<>
void
Impl
<
op
::
Dot
>::
operator
()()
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -40,7 +46,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
...
@@ -40,7 +46,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
NGRAPH_DEBUG
<<
"l_dim_mac="
<<
l_dim_mac
;
NGRAPH_DEBUG
<<
"l_dim_mac="
<<
l_dim_mac
;
NGRAPH_DEBUG
<<
"r_dim_mic="
<<
r_dim_mic
;
NGRAPH_DEBUG
<<
"r_dim_mic="
<<
r_dim_mic
;
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(
0
),
"L"
}
.
add
(
builder
::
Input
{
op_input
(
0
),
"L"
}
.
add_dims
(
"DL"
,
1
,
l_dim_mac
+
1
)
.
add_dims
(
"DL"
,
1
,
l_dim_mac
+
1
)
.
add_dims
(
"DC"
,
1
,
reduce_limit
+
1
))
.
add_dims
(
"DC"
,
1
,
reduce_limit
+
1
))
...
@@ -61,9 +68,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
...
@@ -61,9 +68,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
.
add_indices
(
"dc"
,
1
,
reduce_limit
+
1
)
.
add_indices
(
"dc"
,
1
,
reduce_limit
+
1
)
.
add_indices
(
"dr"
,
r_dim_mic
+
1
,
r_dim_limit
+
1
)))
.
add_indices
(
"dr"
,
r_dim_mic
+
1
,
r_dim_limit
+
1
)))
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Dot
>::
Registration
register_dot
;
Impl
<
op
::
Dot
>::
Registration
register_dot
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_function.cpp
View file @
61df6725
...
@@ -19,10 +19,16 @@
...
@@ -19,10 +19,16 @@
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// FunctionCall invokes a sub-function.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
FunctionCall
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// FunctionCall invokes a sub-function.
template
<>
void
Impl
<
op
::
FunctionCall
>::
operator
()()
{
Build
b
;
Build
b
;
build
()
->
compiler
->
build
(
op
().
get_functions
()[
0
],
&
b
);
build
()
->
compiler
->
build
(
op
().
get_functions
()[
0
],
&
b
);
vertexai
::
plaidml
::
function
f
{
b
.
composer
};
vertexai
::
plaidml
::
function
f
{
b
.
composer
};
...
@@ -30,7 +36,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
...
@@ -30,7 +36,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
for
(
std
::
size_t
idx
=
0
;
idx
<
op
().
get_input_size
();
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
op
().
get_input_size
();
++
idx
)
{
{
auto
*
oitv
=
op
().
get_inputs
()[
idx
].
get_output
().
get_tensor_ptr
().
get
();
auto
*
oitv
=
op
().
get_inputs
()[
idx
].
get_output
().
get_tensor_ptr
().
get
();
auto
*
iitv
=
b
.
func
->
get_parameters
()[
idx
]
->
get_outputs
()[
0
].
get_tensor_ptr
().
get
();
auto
*
iitv
=
b
.
func
->
get_parameters
()[
idx
]
->
get_outputs
()[
0
].
get_tensor_ptr
().
get
();
inputs
.
emplace_back
(
b
.
input_names
.
at
(
iitv
),
build
()
->
bindings
.
at
(
oitv
).
var
);
inputs
.
emplace_back
(
b
.
input_names
.
at
(
iitv
),
build
()
->
bindings
.
at
(
oitv
).
var
);
}
}
vertexai
::
plaidml
::
application
app
{
f
.
apply
(
inputs
)};
vertexai
::
plaidml
::
application
app
{
f
.
apply
(
inputs
)};
...
@@ -39,9 +46,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
...
@@ -39,9 +46,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
auto
*
iotv
=
b
.
func
->
get_results
()[
idx
]
->
get_output_tensor_ptr
().
get
();
auto
*
iotv
=
b
.
func
->
get_results
()[
idx
]
->
get_output_tensor_ptr
().
get
();
set_output
(
idx
,
app
.
get_output
(
b
.
output_names
[
iotv
]));
set_output
(
idx
,
app
.
get_output
(
b
.
output_names
[
iotv
]));
}
}
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
FunctionCall
>::
Registration
register_function_call
;
Impl
<
op
::
FunctionCall
>::
Registration
register_function_call
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_general.cpp
View file @
61df6725
...
@@ -28,10 +28,16 @@
...
@@ -28,10 +28,16 @@
namespace
vp
=
vertexai
::
plaidml
;
namespace
vp
=
vertexai
::
plaidml
;
// Broadcast broadcasts a tensor to a wider shape.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Broadcast
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// Broadcast broadcasts a tensor to a wider shape.
template
<>
void
Impl
<
op
::
Broadcast
>::
operator
()()
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -57,15 +63,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
...
@@ -57,15 +63,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
start_tile_function
()
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(
0
),
"I"
}.
add_rdims
(
"D"
,
in_dim_limit
,
0
))
.
add
(
builder
::
Input
{
op_input
(
0
),
"I"
}.
add_rdims
(
"D"
,
in_dim_limit
,
0
))
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
UnaryContraction
{
"="
}
.
add
(
.
set
(
builder
::
ContractionOutput
{
"O"
}
builder
::
UnaryContraction
{
"="
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add_rindices
(
"o"
,
out_dim_limit
,
0
)
.
add_rindices
(
"o"
,
out_dim_limit
,
0
)
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
out_dim_limit
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
out_dim_limit
;
++
idx
)
{
{
if
(
op
().
get_broadcast_axes
().
count
(
idx
))
if
(
op
().
get_broadcast_axes
().
count
(
idx
))
{
{
out
=
std
::
to_string
(
op
().
get_broadcast_shape
()[
idx
]);
out
=
std
::
to_string
(
op
().
get_broadcast_shape
()[
idx
]);
}
}
else
else
{
{
...
@@ -81,12 +91,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
...
@@ -81,12 +91,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
}
}
})))
})))
.
finalize
());
.
finalize
());
}
}
// Constant fills in a tensor constant.
// Constant fills in a tensor constant.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Constant
>::
operator
()()
void
Impl
<
op
::
Constant
>::
operator
()()
{
{
check_inputs
(
0
);
check_inputs
(
0
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -105,48 +115,51 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
...
@@ -105,48 +115,51 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
switch
(
to_plaidml
(
op
().
get_element_type
()))
switch
(
to_plaidml
(
op
().
get_element_type
()))
{
{
case
PLAIDML_DATA_BOOLEAN
:
case
PLAIDML_DATA_BOOLEAN
:
set_output
(
static_cast
<
std
::
int64_t
>
(
*
static_cast
<
const
char
*>
(
op
().
get_data_ptr
())));
set_output
(
static_cast
<
std
::
int64_t
>
(
*
static_cast
<
const
char
*>
(
op
().
get_data_ptr
())));
return
;
return
;
case
PLAIDML_DATA_INT8
:
case
PLAIDML_DATA_INT8
:
set_output
(
set_output
(
static_cast
<
std
::
int64_t
>
(
static_cast
<
std
::
int64_t
>
(
*
static_cast
<
const
std
::
int8_t
*>
(
op
().
get_data_ptr
())));
*
static_cast
<
const
std
::
int8_t
*>
(
op
().
get_data_ptr
())));
return
;
return
;
case
PLAIDML_DATA_INT16
:
case
PLAIDML_DATA_INT16
:
set_output
(
set_output
(
static_cast
<
std
::
int64_t
>
(
static_cast
<
std
::
int64_t
>
(
*
static_cast
<
const
std
::
int16_t
*>
(
op
().
get_data_ptr
())));
*
static_cast
<
const
std
::
int16_t
*>
(
op
().
get_data_ptr
())));
return
;
return
;
case
PLAIDML_DATA_INT32
:
case
PLAIDML_DATA_INT32
:
set_output
(
set_output
(
static_cast
<
std
::
int64_t
>
(
static_cast
<
std
::
int64_t
>
(
*
static_cast
<
const
std
::
int32_t
*>
(
op
().
get_data_ptr
())));
*
static_cast
<
const
std
::
int32_t
*>
(
op
().
get_data_ptr
())));
return
;
return
;
case
PLAIDML_DATA_INT64
:
case
PLAIDML_DATA_INT64
:
set_output
(
*
static_cast
<
const
std
::
int64_t
*>
(
op
().
get_data_ptr
()));
set_output
(
*
static_cast
<
const
std
::
int64_t
*>
(
op
().
get_data_ptr
()));
return
;
return
;
case
PLAIDML_DATA_UINT8
:
case
PLAIDML_DATA_UINT8
:
set_output
(
set_output
(
static_cast
<
std
::
int64_t
>
(
static_cast
<
std
::
int64_t
>
(
*
static_cast
<
const
std
::
uint8_t
*>
(
op
().
get_data_ptr
())));
*
static_cast
<
const
std
::
uint8_t
*>
(
op
().
get_data_ptr
())));
return
;
return
;
case
PLAIDML_DATA_UINT16
:
case
PLAIDML_DATA_UINT16
:
set_output
(
set_output
(
static_cast
<
std
::
int64_t
>
(
static_cast
<
std
::
int64_t
>
(
*
static_cast
<
const
std
::
uint16_t
*>
(
op
().
get_data_ptr
())));
*
static_cast
<
const
std
::
uint16_t
*>
(
op
().
get_data_ptr
())));
return
;
return
;
case
PLAIDML_DATA_UINT32
:
case
PLAIDML_DATA_UINT32
:
set_output
(
set_output
(
static_cast
<
std
::
int64_t
>
(
static_cast
<
std
::
int64_t
>
(
*
static_cast
<
const
std
::
uint32_t
*>
(
op
().
get_data_ptr
())));
*
static_cast
<
const
std
::
uint32_t
*>
(
op
().
get_data_ptr
())));
return
;
return
;
case
PLAIDML_DATA_UINT64
:
case
PLAIDML_DATA_UINT64
:
set_output
(
set_output
(
static_cast
<
std
::
int64_t
>
(
static_cast
<
std
::
int64_t
>
(
*
static_cast
<
const
std
::
uint64_t
*>
(
op
().
get_data_ptr
())));
*
static_cast
<
const
std
::
uint64_t
*>
(
op
().
get_data_ptr
())));
return
;
return
;
case
PLAIDML_DATA_FLOAT16
:
case
PLAIDML_DATA_FLOAT16
:
set_output
(
static_cast
<
double
>
(
set_output
(
static_cast
<
double
>
(
static_cast
<
float
>
(
*
static_cast
<
const
half
*>
(
op
().
get_data_ptr
()))));
static_cast
<
float
>
(
*
static_cast
<
const
half
*>
(
op
().
get_data_ptr
()))));
return
;
return
;
case
PLAIDML_DATA_FLOAT32
:
case
PLAIDML_DATA_FLOAT32
:
set_output
(
static_cast
<
double
>
(
*
static_cast
<
const
float
*>
(
op
().
get_data_ptr
())));
set_output
(
static_cast
<
double
>
(
*
static_cast
<
const
float
*>
(
op
().
get_data_ptr
())));
return
;
return
;
case
PLAIDML_DATA_FLOAT64
:
case
PLAIDML_DATA_FLOAT64
:
set_output
(
static_cast
<
double
>
(
*
static_cast
<
const
double
*>
(
op
().
get_data_ptr
())));
set_output
(
static_cast
<
double
>
(
*
static_cast
<
const
double
*>
(
op
().
get_data_ptr
())));
return
;
return
;
default
:
break
;
default
:
break
;
}
}
...
@@ -163,22 +176,22 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
...
@@ -163,22 +176,22 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
}
}
set_output
(
tensor
);
set_output
(
tensor
);
}
}
// GetOutputElement pipes one of its N inputs to its output.
// GetOutputElement pipes one of its N inputs to its output.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
GetOutputElement
>::
operator
()()
void
Impl
<
op
::
GetOutputElement
>::
operator
()()
{
{
check_inputs_ge
(
op
().
get_n
()
+
1
);
check_inputs_ge
(
op
().
get_n
()
+
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
op_input
(
op
().
get_n
()));
set_output
(
op_input
(
op
().
get_n
()));
}
}
// Pad adds interior and exterior padding to a tensor.
// Pad adds interior and exterior padding to a tensor.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Pad
>::
operator
()()
void
Impl
<
op
::
Pad
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -214,7 +227,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
...
@@ -214,7 +227,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
auto
out_dsize
=
[
&
](
std
::
size_t
idx
)
{
auto
out_dsize
=
[
&
](
std
::
size_t
idx
)
{
std
::
ostringstream
s
;
std
::
ostringstream
s
;
std
::
size_t
total_pad
=
op
().
get_padding_below
().
at
(
idx
)
+
op
().
get_padding_above
().
at
(
idx
);
std
::
size_t
total_pad
=
op
().
get_padding_below
().
at
(
idx
)
+
op
().
get_padding_above
().
at
(
idx
);
std
::
size_t
in_dsize
=
op
().
get_input_shape
(
0
).
at
(
idx
);
std
::
size_t
in_dsize
=
op
().
get_input_shape
(
0
).
at
(
idx
);
if
(
in_dsize
)
if
(
in_dsize
)
{
{
...
@@ -267,40 +281,48 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
...
@@ -267,40 +281,48 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
if
(
!
any_zero_dims
)
if
(
!
any_zero_dims
)
{
{
f
.
add
(
builder
::
Input
{
op_input
(
0
),
"I"
}.
add_dims
(
"DI"
,
1
,
dim_limit
+
1
))
f
.
add
(
builder
::
Input
{
op_input
(
0
),
"I"
}.
add_dims
(
"DI"
,
1
,
dim_limit
+
1
))
.
add
(
builder
::
UnaryContraction
{
"="
}
.
add
(
.
set
(
builder
::
ContractionOutput
{
"P"
}
builder
::
UnaryContraction
{
"="
}
.
add_indices
(
.
set
(
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
builder
::
ContractionOutput
{
"P"
}
.
add_indices
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
out
=
out_didx
(
idx
);
out
=
out_didx
(
idx
);
}
}
})
})
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
out
=
out_dsize
(
idx
);
out
=
out_dsize
(
idx
);
}
}
}))
}))
.
set
(
builder
::
ContractionInput
{
"I"
}.
add_indices
(
"d"
,
1
,
dim_limit
+
1
)))
.
set
(
builder
::
ContractionInput
{
"I"
}.
add_indices
(
"d"
,
1
,
dim_limit
+
1
)))
.
add
(
builder
::
Elementwise
{
"T"
,
"1"
})
.
add
(
builder
::
Elementwise
{
"T"
,
"1"
})
.
add
(
builder
::
UnaryContraction
{
"="
}
.
add
(
.
set
(
builder
::
ContractionOutput
{
"F"
}
builder
::
UnaryContraction
{
"="
}
.
add_indices
(
.
set
(
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
builder
::
ContractionOutput
{
"F"
}
.
add_indices
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
out
=
out_didx
(
idx
);
out
=
out_didx
(
idx
);
}
}
})
})
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
out
=
out_dsize
(
idx
);
out
=
out_dsize
(
idx
);
}
}
}))
}))
.
set
(
builder
::
ContractionInput
{
"T"
})
.
set
(
builder
::
ContractionInput
{
"T"
})
.
add_constraints
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_constraints
(
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
out
=
flag_constraints
(
idx
);
out
=
flag_constraints
(
idx
);
...
@@ -313,7 +335,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
...
@@ -313,7 +335,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
f
.
add
(
builder
::
UnaryContraction
{
"="
}
f
.
add
(
builder
::
UnaryContraction
{
"="
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add_indices
(
"d"
,
0
,
dim_limit
)
.
add_indices
(
"d"
,
0
,
dim_limit
)
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
out
=
out_dsize
(
idx
);
out
=
out_dsize
(
idx
);
...
@@ -323,12 +346,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
...
@@ -323,12 +346,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
}
}
set_output
(
f
.
finalize
());
set_output
(
f
.
finalize
());
}
}
// Reshape reshapes an input tensor.
// Reshape reshapes an input tensor.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Reshape
>::
operator
()()
void
Impl
<
op
::
Reshape
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -351,13 +374,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
...
@@ -351,13 +374,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
.
add
(
builder
::
UnaryContraction
{
"="
}
.
add
(
builder
::
UnaryContraction
{
"="
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add_indices
(
"d"
,
0
,
out_shape
.
size
())
.
add_indices
(
"d"
,
0
,
out_shape
.
size
())
.
add_dims
(
.
add_dims
([
&
]
(
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
std
::
transform
(
out
)
{
out_shape
.
begin
(),
std
::
transform
(
out_shape
.
begin
(),
out_shape
.
end
(),
out_shape
.
end
(),
out
,
out
,
[](
std
::
size_t
sz
)
{
return
std
::
to_string
(
sz
);
});
[](
std
::
size_t
sz
)
{
return
std
::
to_string
(
sz
);
});
}))
}))
.
set
(
builder
::
ContractionInput
{
"I"
}))
.
set
(
builder
::
ContractionInput
{
"I"
}))
.
finalize
());
.
finalize
());
...
@@ -374,28 +399,31 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
...
@@ -374,28 +399,31 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
// it's also rearranging the elements of the input tensor. This is pretty easy to
// it's also rearranging the elements of the input tensor. This is pretty easy to
// handle with a contraction.
// handle with a contraction.
src
=
src
=
start_tile_function
()
start_tile_function
()
.
add
(
builder
::
Input
{
src
,
"I"
}.
add_dims
(
"D"
,
1
,
dim_limit
+
1
))
.
add
(
builder
::
Input
{
src
,
"I"
}.
add_dims
(
"D"
,
1
,
dim_limit
+
1
))
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
.
add
(
builder
::
UnaryContraction
{
"="
}
builder
::
UnaryContraction
{
"="
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add_indices
([
&
](
.
add_indices
([
&
](
std
::
back_insert_iterator
<
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
out
=
"d"
+
std
::
to_string
(
input_order
[
idx
]
+
1
);
out
=
"d"
+
std
::
to_string
(
input_order
[
idx
]
+
1
);
}
}
})
})
.
add_dims
([
&
](
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
out
=
"D"
+
std
::
to_string
(
input_order
[
idx
]
+
1
);
out
=
"D"
+
std
::
to_string
(
input_order
[
idx
]
+
1
);
}
}
}))
}))
.
set
(
builder
::
ContractionInput
{
"I"
}.
add_indices
(
"d"
,
1
,
dim_limit
+
1
)))
.
set
(
builder
::
ContractionInput
{
"I"
}.
add_indices
(
"d"
,
1
,
dim_limit
+
1
)))
.
finalize
();
.
finalize
();
break
;
break
;
}
}
...
@@ -414,12 +442,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
...
@@ -414,12 +442,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
(
"O"
,
reshape_expr
.
str
()))
.
add
(
builder
::
Elementwise
(
"O"
,
reshape_expr
.
str
()))
.
finalize
());
.
finalize
());
}
}
// Select conditionally selects elements from input tensors.
// Select conditionally selects elements from input tensors.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Select
>::
operator
()()
void
Impl
<
op
::
Select
>::
operator
()()
{
{
check_inputs
(
3
);
check_inputs
(
3
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -430,26 +458,28 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Select>::operator()()
...
@@ -430,26 +458,28 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Select>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"C ? T : F"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"C ? T : F"
})
.
finalize
());
.
finalize
());
}
}
// Used by nGraph for bprop graph generation, no-op as a kernel
// Used by nGraph for bprop graph generation, no-op as a kernel
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
StopGradient
>::
operator
()()
void
Impl
<
op
::
StopGradient
>::
operator
()()
{
{
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"0"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"0"
})
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Broadcast
>::
Registration
register_broadcast
;
Impl
<
op
::
Broadcast
>::
Registration
register_broadcast
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Constant
>::
Registration
register_constant
;
Impl
<
op
::
Constant
>::
Registration
register_constant
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
GetOutputElement
>::
Registration
Impl
<
op
::
GetOutputElement
>::
Registration
register_get_output_element
;
register_get_output_element
;
Impl
<
op
::
Pad
>::
Registration
register_pad
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Pad
>::
Registration
register_pad
;
Impl
<
op
::
Reshape
>::
Registration
register_reshape
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Reshape
>::
Registration
register_reshape
;
Impl
<
op
::
Select
>::
Registration
register_select
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Select
>::
Registration
register_select
;
Impl
<
op
::
StopGradient
>::
Registration
register_stop_gradient
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
StopGradient
>::
Registration
register_stop_gradient
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_index_reduction.cpp
View file @
61df6725
...
@@ -36,13 +36,10 @@ namespace ngraph
...
@@ -36,13 +36,10 @@ namespace ngraph
void
build_index_reduction
(
const
char
*
agg_op
);
void
build_index_reduction
(
const
char
*
agg_op
);
};
};
}
}
}
template
<
typename
O
>
template
<
typename
O
>
void
ngraph
::
runtime
::
plaidml
::
IndexReductionImpl
<
O
>::
build_index_reduction
(
const
char
*
agg_op
)
void
IndexReductionImpl
<
O
>::
build_index_reduction
(
const
char
*
agg_op
)
{
{
this
->
check_inputs
(
1
);
this
->
check_inputs
(
1
);
this
->
check_outputs
(
1
);
this
->
check_outputs
(
1
);
...
@@ -56,16 +53,20 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
...
@@ -56,16 +53,20 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
// Compute the maxes along the specified axis in the input
.
add
(
// Compute the maxes along the specified axis in the input
builder
::
UnaryContraction
{
agg_op
}
builder
::
UnaryContraction
{
agg_op
}
.
set
(
builder
::
ContractionOutput
{
"SelVal"
}
.
set
(
builder
::
ContractionOutput
{
"SelVal"
}
.
add_indices
([
&
](
.
add_indices
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
auto
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
auto
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
out
=
(
idx
==
this
->
op
().
get_reduction_axis
()
?
"rd"
:
"d"
)
+
out
=
(
idx
==
this
->
op
().
get_reduction_axis
()
?
"rd"
:
"d"
)
+
std
::
to_string
(
idx
);
std
::
to_string
(
idx
);
}
}
})
})
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
auto
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
auto
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
if
(
idx
==
this
->
op
().
get_reduction_axis
())
if
(
idx
==
this
->
op
().
get_reduction_axis
())
...
@@ -82,13 +83,14 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
...
@@ -82,13 +83,14 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
.
add
(
// Compare the input against the (broadcasted) max values, and select the indices
.
add
(
// Compare the input against the (broadcasted) max values, and select the indices
// where the max val occurs
// where the max val occurs
builder
::
Elementwise
{
"SelValIdxs"
,
builder
::
Elementwise
{
"SelValIdxs"
,
"I == SelVal ? index(I, "
+
reduction_axis_str
+
") : D"
+
"I == SelVal ? index(I, "
+
reduction_axis_str
+
reduction_axis_str
})
") : D"
+
reduction_axis_str
})
.
add
(
// Select the maximum index
.
add
(
// Select the maximum index
builder
::
UnaryContraction
{
"<"
}
builder
::
UnaryContraction
{
"<"
}
.
set
(
builder
::
ContractionOutput
{
"SelIdx"
}
.
set
(
.
add_indices
(
builder
::
ContractionOutput
{
"SelIdx"
}
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_indices
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
auto
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
auto
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
if
(
idx
!=
this
->
op
().
get_reduction_axis
())
if
(
idx
!=
this
->
op
().
get_reduction_axis
())
...
@@ -97,7 +99,8 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
...
@@ -97,7 +99,8 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
}
}
}
}
})
})
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
auto
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
auto
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
if
(
idx
!=
this
->
op
().
get_reduction_axis
())
if
(
idx
!=
this
->
op
().
get_reduction_axis
())
...
@@ -106,41 +109,45 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
...
@@ -106,41 +109,45 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
}
}
}
}
}))
}))
.
set
(
builder
::
ContractionInput
{
"SelValIdxs"
}.
add_indices
(
"d"
,
0
,
dim_limit
)))
.
set
(
builder
::
ContractionInput
{
"SelValIdxs"
}.
add_indices
(
"d"
,
0
,
dim_limit
)))
.
add
(
// Convert to the requested output element type (if any)
.
add
(
// Convert to the requested output element type (if any)
builder
::
Elementwise
{
"O"
,
builder
::
Elementwise
{
tile_converter
(
"SelIdx"
,
this
->
op
().
get_index_element_type
())})
"O"
,
tile_converter
(
"SelIdx"
,
this
->
op
().
get_index_element_type
())})
.
finalize
());
.
finalize
());
}
}
template
<>
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
ArgMax
>
struct
ParentImpl
<
op
::
ArgMax
>
{
{
using
Type
=
IndexReductionImpl
<
ngraph
::
op
::
ArgMax
>
;
using
Type
=
IndexReductionImpl
<
op
::
ArgMax
>
;
};
};
template
<>
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
ArgMin
>
struct
ParentImpl
<
op
::
ArgMin
>
{
{
using
Type
=
IndexReductionImpl
<
ngraph
::
op
::
ArgMin
>
;
using
Type
=
IndexReductionImpl
<
op
::
ArgMin
>
;
};
};
// ArgMax computes the maximum index along a tensor axis.
// ArgMax computes the maximum index along a tensor axis.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ArgMax
>::
operator
()()
void
Impl
<
op
::
ArgMax
>::
operator
()()
{
{
build_index_reduction
(
">"
);
build_index_reduction
(
">"
);
}
}
// ArgMin computes the minimum index along a tensor axis.
// ArgMin computes the minimum index along a tensor axis.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ArgMin
>::
operator
()()
void
Impl
<
op
::
ArgMin
>::
operator
()()
{
{
build_index_reduction
(
"<"
);
build_index_reduction
(
"<"
);
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ArgMax
>::
Registration
register_argmax
;
Impl
<
op
::
ArgMax
>::
Registration
register_argmax
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ArgMin
>::
Registration
register_argmin
;
Impl
<
op
::
ArgMin
>::
Registration
register_argmin
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_io.cpp
View file @
61df6725
...
@@ -20,10 +20,16 @@
...
@@ -20,10 +20,16 @@
namespace
vp
=
vertexai
::
plaidml
;
namespace
vp
=
vertexai
::
plaidml
;
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Parameter
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
template
<>
void
Impl
<
op
::
Parameter
>::
operator
()()
{
check_inputs
(
0
);
check_inputs
(
0
);
check_outputs
(
1
);
check_outputs
(
1
);
vp
::
placeholder
ph
{
build
()
->
io_dim_override
?
build
()
->
io_dim_override_count
vp
::
placeholder
ph
{
build
()
->
io_dim_override
?
build
()
->
io_dim_override_count
...
@@ -33,22 +39,25 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::operator()()
...
@@ -33,22 +39,25 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::operator()()
build
()
->
bindings
.
emplace
(
tv
,
TensorInfo
{
ph
,
TensorContents
::
DATA
});
build
()
->
bindings
.
emplace
(
tv
,
TensorInfo
{
ph
,
TensorContents
::
DATA
});
build
()
->
composer
.
input
(
name
,
ph
);
build
()
->
composer
.
input
(
name
,
ph
);
build
()
->
input_names
.
emplace
(
tv
,
std
::
move
(
name
));
build
()
->
input_names
.
emplace
(
tv
,
std
::
move
(
name
));
}
}
// Result binds a PlaidML variable to a composed function output.
// Result binds a PlaidML variable to a composed function output.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Result
>::
operator
()()
void
Impl
<
op
::
Result
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
std
::
string
name
=
std
::
string
{
"O"
}
+
std
::
to_string
(
build
()
->
output_names
.
size
());
std
::
string
name
=
std
::
string
{
"O"
}
+
std
::
to_string
(
build
()
->
output_names
.
size
());
descriptor
::
Tensor
*
tv
=
op
().
get_output_tensor_ptr
().
get
();
descriptor
::
Tensor
*
tv
=
op
().
get_output_tensor_ptr
().
get
();
build
()
->
composer
.
output
(
name
,
op_input
());
build
()
->
composer
.
output
(
name
,
op_input
());
build
()
->
output_names
.
emplace
(
tv
,
std
::
move
(
name
));
build
()
->
output_names
.
emplace
(
tv
,
std
::
move
(
name
));
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Parameter
>::
Registration
register_parameter
;
Impl
<
op
::
Parameter
>::
Registration
register_parameter
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Result
>::
Registration
register_result
;
Impl
<
op
::
Result
>::
Registration
register_result
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_local_response_norm.cpp
View file @
61df6725
...
@@ -17,21 +17,29 @@
...
@@ -17,21 +17,29 @@
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// LRN implements Local Response Normalization
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
LRN
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// LRN implements Local Response Normalization
template
<>
void
Impl
<
op
::
LRN
>::
operator
()()
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
auto
dim_limit
=
op
().
get_inputs
()[
0
].
get_shape
().
size
();
auto
dim_limit
=
op
().
get_inputs
()[
0
].
get_shape
().
size
();
auto
rank
=
dim_limit
-
2
;
auto
rank
=
dim_limit
-
2
;
auto
distance
=
op
().
get_nsize
()
/
2
;
auto
distance
=
op
().
get_nsize
()
/
2
;
std
::
ostringstream
div_expr
;
std
::
ostringstream
div_expr
;
div_expr
<<
"I / pow("
<<
op
().
get_bias
()
<<
".0 + (("
<<
op
().
get_alpha
()
<<
".0 / "
div_expr
<<
"I / pow("
<<
op
().
get_bias
()
<<
".0 + (("
<<
op
().
get_alpha
()
<<
op
().
get_nsize
()
<<
".0) * S), "
<<
op
().
get_beta
()
<<
".0)"
;
<<
".0 / "
<<
op
().
get_nsize
()
<<
".0) * S), "
<<
op
().
get_beta
()
<<
".0)"
;
set_output
(
set_output
(
start_tile_function
()
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(),
"I"
}.
add_dims
({
"N"
,
"C"
}).
add_dims
(
"D"
,
0
,
rank
))
.
add
(
builder
::
Input
{
op_input
(),
"I"
}
.
add_dims
({
"N"
,
"C"
})
.
add_dims
(
"D"
,
0
,
rank
))
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"ISQ"
,
"I * I"
})
.
add
(
builder
::
Elementwise
{
"ISQ"
,
"I * I"
})
.
add
(
builder
::
UnaryContraction
{
"+"
}
.
add
(
builder
::
UnaryContraction
{
"+"
}
...
@@ -43,14 +51,18 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::operator()()
...
@@ -43,14 +51,18 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::operator()()
.
set
(
builder
::
ContractionInput
{
"ISQ"
}
.
set
(
builder
::
ContractionInput
{
"ISQ"
}
.
add_indices
({
"n"
,
"c + z - "
+
std
::
to_string
(
distance
)})
.
add_indices
({
"n"
,
"c + z - "
+
std
::
to_string
(
distance
)})
.
add_indices
(
"d"
,
0
,
rank
))
.
add_indices
(
"d"
,
0
,
rank
))
.
add_constraints
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_constraints
(
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
out
=
"z < "
+
std
::
to_string
(
op
().
get_nsize
());
out
=
"z < "
+
std
::
to_string
(
op
().
get_nsize
());
}))
}))
.
add
(
builder
::
Elementwise
{
"O"
,
div_expr
.
str
()})
.
add
(
builder
::
Elementwise
{
"O"
,
div_expr
.
str
()})
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
LRN
>::
Registration
register_local_response_norm
;
Impl
<
op
::
LRN
>::
Registration
register_local_response_norm
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_logical.cpp
View file @
61df6725
...
@@ -19,10 +19,16 @@
...
@@ -19,10 +19,16 @@
#include "ngraph/op/or.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// And performs a simple elementwise logical and.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
And
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// And performs a simple elementwise logical and.
template
<>
void
Impl
<
op
::
And
>::
operator
()()
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -32,12 +38,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::And>::operator()()
...
@@ -32,12 +38,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::And>::operator()()
.
add
(
builder
::
Elementwise
{
"C"
,
"A ? B : A"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A ? B : A"
})
.
finalize
(),
.
finalize
(),
TensorContents
::
LOGICAL
);
TensorContents
::
LOGICAL
);
}
}
// Not performs a simple elementwise logical not.
// Not performs a simple elementwise logical not.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Not
>::
operator
()()
void
Impl
<
op
::
Not
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -46,12 +52,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Not>::operator()()
...
@@ -46,12 +52,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Not>::operator()()
.
add
(
builder
::
Elementwise
{
"O"
,
"cmp_eq(I, 0)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"cmp_eq(I, 0)"
})
.
finalize
(),
.
finalize
(),
TensorContents
::
LOGICAL
);
TensorContents
::
LOGICAL
);
}
}
// Or performs a simple elementwise logical or.
// Or performs a simple elementwise logical or.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Or
>::
operator
()()
void
Impl
<
op
::
Or
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -61,11 +67,14 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Or>::operator()()
...
@@ -61,11 +67,14 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Or>::operator()()
.
add
(
builder
::
Elementwise
{
"C"
,
"A ? A : B"
})
.
add
(
builder
::
Elementwise
{
"C"
,
"A ? A : B"
})
.
finalize
(),
.
finalize
(),
TensorContents
::
LOGICAL
);
TensorContents
::
LOGICAL
);
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
And
>::
Registration
register_and
;
Impl
<
op
::
And
>::
Registration
register_and
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Not
>::
Registration
register_not
;
Impl
<
op
::
Not
>::
Registration
register_not
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Or
>::
Registration
register_or
;
Impl
<
op
::
Or
>::
Registration
register_or
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_one_hot.cpp
View file @
61df6725
...
@@ -20,10 +20,16 @@
...
@@ -20,10 +20,16 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
// OneHot performs one-hot encoding along the requested axis.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
OneHot
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// OneHot performs one-hot encoding along the requested axis.
template
<>
void
Impl
<
op
::
OneHot
>::
operator
()()
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -68,9 +74,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()()
...
@@ -68,9 +74,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()()
.
add
(
builder
::
Input
{
op_input
(),
"I"
}.
add_dims
(
"D"
,
0
,
in_shape
.
size
()))
.
add
(
builder
::
Input
{
op_input
(),
"I"
}.
add_dims
(
"D"
,
0
,
in_shape
.
size
()))
.
add
(
builder
::
Input
{
static_cast
<
std
::
int64_t
>
(
0
),
"Zero"
})
.
add
(
builder
::
Input
{
static_cast
<
std
::
int64_t
>
(
0
),
"Zero"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
UnaryContraction
{
"="
}
.
add
(
.
set
(
builder
::
ContractionOutput
{
"ZS"
}
builder
::
UnaryContraction
{
"="
}
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
set
(
builder
::
ContractionOutput
{
"ZS"
}
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
out_shape
.
size
();
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
out_shape
.
size
();
++
idx
)
{
{
if
(
idx
==
op
().
get_one_hot_axis
())
if
(
idx
==
op
().
get_one_hot_axis
())
...
@@ -85,15 +94,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()()
...
@@ -85,15 +94,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()()
})
})
.
add_indices
(
"d"
,
0
,
out_shape
.
size
()))
.
add_indices
(
"d"
,
0
,
out_shape
.
size
()))
.
set
(
builder
::
ContractionInput
{
"Zero"
}))
.
set
(
builder
::
ContractionInput
{
"Zero"
}))
.
add
(
builder
::
Elementwise
{
"Idx"
,
.
add
(
builder
::
Elementwise
{
"index(ZS, "
+
std
::
to_string
(
op
().
get_one_hot_axis
())
+
")"
})
"Idx"
,
"index(ZS, "
+
std
::
to_string
(
op
().
get_one_hot_axis
())
+
")"
})
.
add
(
builder
::
Elementwise
{
"IS"
,
"reshape(I, "
+
in_reshape
.
str
()
+
")"
})
.
add
(
builder
::
Elementwise
{
"IS"
,
"reshape(I, "
+
in_reshape
.
str
()
+
")"
})
.
add
(
builder
::
Elementwise
{
"OV"
,
"IS == Idx ? 1 : 0"
})
.
add
(
builder
::
Elementwise
{
"OV"
,
"IS == Idx ? 1 : 0"
})
.
add
(
builder
::
Elementwise
{
"O"
,
tile_converter
(
"OV"
,
op
().
get_element_type
())})
.
add
(
builder
::
Elementwise
{
"O"
,
tile_converter
(
"OV"
,
op
().
get_element_type
())})
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
OneHot
>::
Registration
register_one_hot
;
Impl
<
op
::
OneHot
>::
Registration
register_one_hot
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_pool.cpp
View file @
61df6725
...
@@ -20,10 +20,16 @@
...
@@ -20,10 +20,16 @@
#include "ngraph/runtime/plaidml/plaidml_convpool_formatter.hpp"
#include "ngraph/runtime/plaidml/plaidml_convpool_formatter.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// AvgPool implements a batch average pooling operation.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
AvgPool
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// AvgPool implements a batch average pooling operation.
template
<>
void
Impl
<
op
::
AvgPool
>::
operator
()()
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -92,12 +98,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPool>::operator()()
...
@@ -92,12 +98,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPool>::operator()()
f
.
add
(
cpf
.
PoolContraction
()).
add
(
builder
::
Elementwise
{
"O"
,
"S / Count"
});
f
.
add
(
cpf
.
PoolContraction
()).
add
(
builder
::
Elementwise
{
"O"
,
"S / Count"
});
set_output
(
f
.
finalize
());
set_output
(
f
.
finalize
());
}
}
// MaxPool implements a batch max pooling operation.
// MaxPool implements a batch max pooling operation.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
MaxPool
>::
operator
()()
void
Impl
<
op
::
MaxPool
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -156,11 +162,11 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPool>::operator()()
...
@@ -156,11 +162,11 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPool>::operator()()
.
add
(
cpf
.
O_out_header
())
.
add
(
cpf
.
O_out_header
())
.
add
(
cpf
.
PoolContraction
())
.
add
(
cpf
.
PoolContraction
())
.
finalize
());
.
finalize
());
}
}
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
AvgPoolBackprop
>::
operator
()()
void
Impl
<
op
::
AvgPoolBackprop
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -174,7 +180,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()()
...
@@ -174,7 +180,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()()
if
(
include_padding
)
if
(
include_padding
)
{
{
throw
std
::
runtime_error
(
"Include padding in average not yet implemented in PlaidML"
);
throw
std
::
runtime_error
(
"Include padding in average not yet implemented in PlaidML"
);
}
}
ngraph
::
CoordinateDiff
pad_above
;
ngraph
::
CoordinateDiff
pad_above
;
...
@@ -229,18 +236,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()()
...
@@ -229,18 +236,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()()
{
{
std
::
ostringstream
s
;
std
::
ostringstream
s
;
s
<<
"XI"
<<
i
-
2
;
s
<<
"XI"
<<
i
-
2
;
ret
.
add
(
builder
::
Input
{
static_cast
<
std
::
int64_t
>
(
forward_arg_shape
[
i
]),
s
.
str
()});
ret
.
add
(
builder
::
Input
{
static_cast
<
std
::
int64_t
>
(
forward_arg_shape
[
i
]),
s
.
str
()});
}
}
set_output
(
ret
.
add
(
cpf
.
Broadcast_Ones
())
set_output
(
ret
.
add
(
cpf
.
Broadcast_Ones
())
.
add
(
cpf
.
Count
())
.
add
(
cpf
.
Count
())
.
add
(
builder
::
Elementwise
{
"S"
,
"DO / Count"
})
.
add
(
builder
::
Elementwise
{
"S"
,
"DO / Count"
})
.
add
(
cpf
.
PoolContraction
())
.
add
(
cpf
.
PoolContraction
())
.
finalize
());
.
finalize
());
}
}
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
MaxPoolBackprop
>::
operator
()()
void
Impl
<
op
::
MaxPoolBackprop
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -299,14 +307,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPoolBackprop>::operator()()
...
@@ -299,14 +307,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPoolBackprop>::operator()()
.
add
(
cpf
.
PoolContraction
())
.
add
(
cpf
.
PoolContraction
())
.
add
(
cpf
.
PoolDerivContraction
())
.
add
(
cpf
.
PoolDerivContraction
())
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
AvgPool
>::
Registration
register_avg_pool
;
Impl
<
op
::
AvgPool
>::
Registration
register_avg_pool
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
MaxPool
>::
Registration
register_max_pool
;
Impl
<
op
::
MaxPool
>::
Registration
register_max_pool
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
AvgPoolBackprop
>::
Registration
Impl
<
op
::
AvgPoolBackprop
>::
Registration
register_avg_pool_backprop
;
register_avg_pool_backprop
;
Impl
<
op
::
MaxPoolBackprop
>::
Registration
register_max_pool_backprop
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
MaxPoolBackprop
>::
Registration
}
register_max_pool_backprop
;
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_reduce.cpp
View file @
61df6725
...
@@ -42,13 +42,10 @@ namespace ngraph
...
@@ -42,13 +42,10 @@ namespace ngraph
void
build_reduction
(
const
char
*
agg_op
);
void
build_reduction
(
const
char
*
agg_op
);
};
};
}
}
}
template
<
typename
O
>
template
<
typename
O
>
void
ngraph
::
runtime
::
plaidml
::
ReductionImpl
<
O
>::
build_reduction
(
const
char
*
agg_op
)
void
ReductionImpl
<
O
>::
build_reduction
(
const
char
*
agg_op
)
{
{
this
->
check_inputs
(
1
);
this
->
check_inputs
(
1
);
this
->
check_outputs
(
1
);
this
->
check_outputs
(
1
);
...
@@ -68,81 +65,86 @@ void ngraph::runtime::plaidml::ReductionImpl<O>::build_reduction(const char* agg
...
@@ -68,81 +65,86 @@ void ngraph::runtime::plaidml::ReductionImpl<O>::build_reduction(const char* agg
this
->
start_tile_function
()
this
->
start_tile_function
()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Input
{
this
->
op_input
(
0
),
"I"
}.
add_dims
(
"D"
,
1
,
in_dim_limit
+
1
))
.
add
(
builder
::
Input
{
this
->
op_input
(
0
),
"I"
}.
add_dims
(
.
add
(
builder
::
UnaryContraction
{
agg_op
}
"D"
,
1
,
in_dim_limit
+
1
))
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add
(
.
add_indices
(
builder
::
UnaryContraction
{
agg_op
}
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add_indices
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
out_idxs
.
size
();
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
out_idxs
.
size
();
++
idx
)
{
{
out
=
"d"
+
std
::
to_string
(
out_idxs
[
idx
]
+
1
);
out
=
"d"
+
std
::
to_string
(
out_idxs
[
idx
]
+
1
);
}
}
})
})
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
out_idxs
.
size
();
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
out_idxs
.
size
();
++
idx
)
{
{
out
=
"D"
+
std
::
to_string
(
out_idxs
[
idx
]
+
1
);
out
=
"D"
+
std
::
to_string
(
out_idxs
[
idx
]
+
1
);
}
}
}))
}))
.
set
(
builder
::
ContractionInput
{
"I"
}.
add_indices
(
"d"
,
1
,
in_dim_limit
+
1
)))
.
set
(
builder
::
ContractionInput
{
"I"
}.
add_indices
(
"d"
,
1
,
in_dim_limit
+
1
)))
.
finalize
());
.
finalize
());
}
}
template
<>
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
Max
>
struct
ParentImpl
<
op
::
Max
>
{
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ReductionImpl
<
ngraph
::
op
::
Max
>
;
using
Type
=
ReductionImpl
<
op
::
Max
>
;
};
};
template
<>
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
Min
>
struct
ParentImpl
<
op
::
Min
>
{
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ReductionImpl
<
ngraph
::
op
::
Min
>
;
using
Type
=
ReductionImpl
<
op
::
Min
>
;
};
};
template
<>
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
Product
>
struct
ParentImpl
<
op
::
Product
>
{
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ReductionImpl
<
ngraph
::
op
::
Product
>
;
using
Type
=
ReductionImpl
<
op
::
Product
>
;
};
};
template
<>
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
Reduce
>
struct
ParentImpl
<
op
::
Reduce
>
{
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ReductionImpl
<
ngraph
::
op
::
Reduce
>
;
using
Type
=
ReductionImpl
<
op
::
Reduce
>
;
};
};
template
<>
template
<>
struct
ngraph
::
runtime
::
plaidml
::
ParentImpl
<
ngraph
::
op
::
Sum
>
struct
ParentImpl
<
op
::
Sum
>
{
{
using
Type
=
ngraph
::
runtime
::
plaidml
::
ReductionImpl
<
ngraph
::
op
::
Sum
>
;
using
Type
=
ReductionImpl
<
op
::
Sum
>
;
};
};
// Max reduces a tensor, taking the maximum along the specified axes.
// Max reduces a tensor, taking the maximum along the specified axes.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Max
>::
operator
()()
void
Impl
<
op
::
Max
>::
operator
()()
{
{
build_reduction
(
">"
);
build_reduction
(
">"
);
}
}
// Min reduces a tensor, taking the minimum along the specified axes.
// Min reduces a tensor, taking the minimum along the specified axes.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Min
>::
operator
()()
void
Impl
<
op
::
Min
>::
operator
()()
{
{
build_reduction
(
"<"
);
build_reduction
(
"<"
);
}
}
// Min reduces a tensor, taking the product along the specified axes.
// Min reduces a tensor, taking the product along the specified axes.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Product
>::
operator
()()
void
Impl
<
op
::
Product
>::
operator
()()
{
{
build_reduction
(
"*"
);
build_reduction
(
"*"
);
}
}
// Reduce reduces a tensor with an arbitrary user-supplied reduction operation.
// Reduce reduces a tensor with an arbitrary user-supplied reduction operation.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Reduce
>::
operator
()()
void
Impl
<
op
::
Reduce
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -183,10 +185,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
...
@@ -183,10 +185,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
start_tile_function
()
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(
1
),
"I"
})
.
add
(
builder
::
Input
{
op_input
(
1
),
"I"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
UnaryContraction
{
"="
}
.
add
(
.
set
(
builder
::
ContractionOutput
{
"O"
}
builder
::
UnaryContraction
{
"="
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add_indices
(
"d"
,
0
,
agg_dim_limit
)
.
add_indices
(
"d"
,
0
,
agg_dim_limit
)
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
auto
idx
=
0
;
idx
<
agg_dim_limit
;
++
idx
)
for
(
auto
idx
=
0
;
idx
<
agg_dim_limit
;
++
idx
)
{
{
out
=
"1"
;
out
=
"1"
;
...
@@ -205,9 +210,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
...
@@ -205,9 +210,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
UnaryContraction
{
"="
}
.
add
(
builder
::
UnaryContraction
{
"="
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add_indices
(
.
add_indices
([
&
](
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
for
(
std
::
size_t
idx
=
0
;
idx
<
input_shape
.
size
();
++
idx
)
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
input_shape
.
size
();
++
idx
)
{
{
if
(
!
op
().
get_reduction_axes
().
count
(
idx
))
if
(
!
op
().
get_reduction_axes
().
count
(
idx
))
{
{
...
@@ -215,9 +223,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
...
@@ -215,9 +223,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
}
}
}
}
})
})
.
add_dims
(
.
add_dims
([
&
](
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
for
(
std
::
size_t
idx
=
0
;
idx
<
input_shape
.
size
();
++
idx
)
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
input_shape
.
size
();
++
idx
)
{
{
if
(
!
op
().
get_reduction_axes
().
count
(
idx
))
if
(
!
op
().
get_reduction_axes
().
count
(
idx
))
{
{
...
@@ -225,8 +236,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
...
@@ -225,8 +236,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
}
}
}
}
}))
}))
.
set
(
builder
::
ContractionInput
{
"I"
}.
add_indices
(
.
set
(
builder
::
ContractionInput
{
"I"
}.
add_indices
([
&
]
(
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
input_shape
.
size
();
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
input_shape
.
size
();
++
idx
)
{
{
std
::
size_t
cidx
=
0
;
std
::
size_t
cidx
=
0
;
...
@@ -244,20 +255,23 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
...
@@ -244,20 +255,23 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
}
}
set_output
(
result
);
set_output
(
result
);
}
}
// Sum reduces a tensor, summing the specified axes.
// Sum reduces a tensor, summing the specified axes.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sum
>::
operator
()()
void
Impl
<
op
::
Sum
>::
operator
()()
{
{
build_reduction
(
"+"
);
build_reduction
(
"+"
);
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Max
>::
Registration
register_max
;
Impl
<
op
::
Max
>::
Registration
register_max
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Min
>::
Registration
register_min
;
Impl
<
op
::
Min
>::
Registration
register_min
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Product
>::
Registration
register_product
;
Impl
<
op
::
Product
>::
Registration
register_product
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Reduce
>::
Registration
register_reduce
;
Impl
<
op
::
Reduce
>::
Registration
register_reduce
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sum
>::
Registration
register_sum
;
Impl
<
op
::
Sum
>::
Registration
register_sum
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_replace_slice.cpp
View file @
61df6725
...
@@ -19,10 +19,16 @@
...
@@ -19,10 +19,16 @@
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// ReplaceSlice replaces part of a tensor with another tensor.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ReplaceSlice
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// ReplaceSlice replaces part of a tensor with another tensor.
template
<>
void
Impl
<
op
::
ReplaceSlice
>::
operator
()()
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -43,11 +49,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
...
@@ -43,11 +49,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
.
add
(
builder
::
Input
{
op_input
(
0
),
"L"
}.
add_dims
(
"D"
,
0
,
shape
.
size
()))
.
add
(
builder
::
Input
{
op_input
(
0
),
"L"
}.
add_dims
(
"D"
,
0
,
shape
.
size
()))
.
add
(
builder
::
Input
{
op_input
(
1
),
"S"
}.
add_dims
(
"SD"
,
0
,
shape
.
size
()))
.
add
(
builder
::
Input
{
op_input
(
1
),
"S"
}.
add_dims
(
"SD"
,
0
,
shape
.
size
()))
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
UnaryContraction
{
"="
}
.
add
(
.
set
(
builder
::
ContractionOutput
{
"O"
}
builder
::
UnaryContraction
{
"="
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add_dims
(
"D"
,
0
,
shape
.
size
())
.
add_dims
(
"D"
,
0
,
shape
.
size
())
.
add_indices
(
.
add_indices
([
&
]
(
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
shape
.
size
();
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
shape
.
size
();
++
idx
)
{
{
auto
stride
=
op
().
get_strides
()[
idx
];
auto
stride
=
op
().
get_strides
()[
idx
];
...
@@ -73,8 +81,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
...
@@ -73,8 +81,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
out
=
didx
.
str
();
out
=
didx
.
str
();
}
}
}))
}))
.
set
(
builder
::
ContractionInput
{
"S"
}.
add_indices
(
"d"
,
0
,
shape
.
size
()))
.
set
(
builder
::
ContractionInput
{
"S"
}.
add_indices
(
.
add_constraints
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
"d"
,
0
,
shape
.
size
()))
.
add_constraints
(
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
shape
.
size
();
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
shape
.
size
();
++
idx
)
{
{
out
=
"d"
+
std
::
to_string
(
idx
)
+
" < "
+
out
=
"d"
+
std
::
to_string
(
idx
)
+
" < "
+
...
@@ -84,9 +94,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
...
@@ -84,9 +94,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
})
})
.
set_default
(
"L"
))
.
set_default
(
"L"
))
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
ReplaceSlice
>::
Registration
register_replace_slice
;
Impl
<
op
::
ReplaceSlice
>::
Registration
register_replace_slice
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_reverse.cpp
View file @
61df6725
...
@@ -19,10 +19,16 @@
...
@@ -19,10 +19,16 @@
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Reverse reverses the selected axes within a tensor.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Reverse
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// Reverse reverses the selected axes within a tensor.
template
<>
void
Impl
<
op
::
Reverse
>::
operator
()()
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -35,8 +41,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()()
...
@@ -35,8 +41,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()()
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add_indices
(
"d"
,
0
,
shape
.
size
())
.
add_indices
(
"d"
,
0
,
shape
.
size
())
.
add_dims
(
"D"
,
0
,
shape
.
size
()))
.
add_dims
(
"D"
,
0
,
shape
.
size
()))
.
set
(
builder
::
ContractionInput
{
"I"
}.
add_indices
(
.
set
(
builder
::
ContractionInput
{
"I"
}.
add_indices
([
&
]
(
[
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
shape
.
size
();
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
shape
.
size
();
++
idx
)
{
{
auto
sidx
=
std
::
to_string
(
idx
);
auto
sidx
=
std
::
to_string
(
idx
);
...
@@ -51,9 +57,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()()
...
@@ -51,9 +57,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()()
}
}
})))
})))
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Reverse
>::
Registration
register_reverse
;
Impl
<
op
::
Reverse
>::
Registration
register_reverse
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_slice.cpp
View file @
61df6725
...
@@ -18,10 +18,16 @@
...
@@ -18,10 +18,16 @@
#include "ngraph/op/slice.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Slice takes a sub-slice of a tensor.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Slice
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// Slice takes a sub-slice of a tensor.
template
<>
void
Impl
<
op
::
Slice
>::
operator
()()
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
NGRAPH_DEBUG
<<
"Slice: low: "
<<
op
().
get_lower_bounds
();
NGRAPH_DEBUG
<<
"Slice: low: "
<<
op
().
get_lower_bounds
();
...
@@ -33,17 +39,21 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()()
...
@@ -33,17 +39,21 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()()
start_tile_function
()
start_tile_function
()
.
add
(
builder
::
Input
{
op_input
(),
"I"
}.
add_dims
(
"ID"
,
0
,
dim_limit
))
.
add
(
builder
::
Input
{
op_input
(),
"I"
}.
add_dims
(
"ID"
,
0
,
dim_limit
))
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
UnaryContraction
{
"="
}
.
add
(
.
set
(
builder
::
ContractionOutput
{
"O"
}
builder
::
UnaryContraction
{
"="
}
.
set
(
builder
::
ContractionOutput
{
"O"
}
.
add_indices
(
"od"
,
0
,
dim_limit
)
.
add_indices
(
"od"
,
0
,
dim_limit
)
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
for
(
std
::
size_t
idx
=
0
;
idx
<
dim_limit
;
++
idx
)
{
{
std
::
ostringstream
s
;
std
::
ostringstream
s
;
std
::
size_t
stride
=
op
().
get_strides
()[
idx
];
std
::
size_t
stride
=
op
().
get_strides
()[
idx
];
std
::
ptrdiff_t
trim_count
=
std
::
ptrdiff_t
trim_count
=
op
().
get_lower_bounds
()[
idx
]
+
op
().
get_lower_bounds
()[
idx
]
+
(
shape
[
idx
]
-
op
().
get_upper_bounds
()[
idx
])
+
1
-
stride
;
(
shape
[
idx
]
-
op
().
get_upper_bounds
()[
idx
])
+
1
-
stride
;
if
((
stride
!=
1
)
&&
trim_count
)
if
((
stride
!=
1
)
&&
trim_count
)
{
{
s
<<
"("
;
s
<<
"("
;
...
@@ -96,9 +106,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()()
...
@@ -96,9 +106,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()()
}
}
})))
})))
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Slice
>::
Registration
register_slice
;
Impl
<
op
::
Slice
>::
Registration
register_slice
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_softmax.cpp
View file @
61df6725
...
@@ -19,10 +19,16 @@
...
@@ -19,10 +19,16 @@
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Softmax implements a standard ML softmax operation.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Softmax
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// Softmax implements a standard ML softmax operation.
template
<>
void
Impl
<
op
::
Softmax
>::
operator
()()
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
...
@@ -30,7 +36,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
...
@@ -30,7 +36,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
auto
dim_limit
=
shape
.
size
();
auto
dim_limit
=
shape
.
size
();
auto
f
=
start_tile_function
();
auto
f
=
start_tile_function
();
f
.
add
(
builder
::
Input
{
op_input
(
0
),
"I"
}.
add_dims
(
"D"
,
0
,
dim_limit
)).
add
(
builder
::
Output
{
"O"
});
f
.
add
(
builder
::
Input
{
op_input
(
0
),
"I"
}.
add_dims
(
"D"
,
0
,
dim_limit
))
.
add
(
builder
::
Output
{
"O"
});
bool
reorder_needed
=
false
;
bool
reorder_needed
=
false
;
bool
saw_element
=
false
;
bool
saw_element
=
false
;
...
@@ -71,7 +78,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
...
@@ -71,7 +78,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
{
{
f
.
add
(
builder
::
UnaryContraction
{
"="
}
f
.
add
(
builder
::
UnaryContraction
{
"="
}
.
set
(
builder
::
ContractionOutput
{
"RI"
}
.
set
(
builder
::
ContractionOutput
{
"RI"
}
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_dims
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
auto
idx
:
group_idxs
)
for
(
auto
idx
:
group_idxs
)
{
{
out
=
"D"
+
std
::
to_string
(
idx
);
out
=
"D"
+
std
::
to_string
(
idx
);
...
@@ -81,7 +89,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
...
@@ -81,7 +89,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
out
=
"D"
+
std
::
to_string
(
idx
);
out
=
"D"
+
std
::
to_string
(
idx
);
}
}
})
})
.
add_indices
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
.
add_indices
([
&
](
std
::
back_insert_iterator
<
std
::
list
<
std
::
string
>>
out
)
{
for
(
auto
idx
:
group_idxs
)
for
(
auto
idx
:
group_idxs
)
{
{
out
=
"d"
+
std
::
to_string
(
idx
);
out
=
"d"
+
std
::
to_string
(
idx
);
...
@@ -117,7 +126,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
...
@@ -117,7 +126,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
{
{
// Take the softmax.
// Take the softmax.
std
::
ostringstream
softmax
;
std
::
ostringstream
softmax
;
softmax
<<
"builtin_softmax("
<<
input
<<
", "
<<
groups
<<
", "
<<
elements
<<
")"
;
softmax
<<
"builtin_softmax("
<<
input
<<
", "
<<
groups
<<
", "
<<
elements
<<
")"
;
f
.
add
(
builder
::
Elementwise
{
output
,
softmax
.
str
()});
f
.
add
(
builder
::
Elementwise
{
output
,
softmax
.
str
()});
}
}
...
@@ -159,9 +169,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
...
@@ -159,9 +169,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
}
}
set_output
(
f
.
finalize
());
set_output
(
f
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Softmax
>::
Registration
register_softmax
;
Impl
<
op
::
Softmax
>::
Registration
register_softmax
;
}
}
}
}
}
src/ngraph/runtime/plaidml/plaidml_ops_transcendental.cpp
View file @
61df6725
...
@@ -29,10 +29,16 @@
...
@@ -29,10 +29,16 @@
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// acos performs a simple elementwise arccos function.
namespace
ngraph
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Acos
>::
operator
()()
{
{
namespace
runtime
{
namespace
plaidml
{
// acos performs a simple elementwise arccos function.
template
<>
void
Impl
<
op
::
Acos
>::
operator
()()
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -40,12 +46,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Acos>::operator()()
...
@@ -40,12 +46,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Acos>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"acos(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"acos(I)"
})
.
finalize
());
.
finalize
());
}
}
// asin performs a simple elementwise arcsin function.
// asin performs a simple elementwise arcsin function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Asin
>::
operator
()()
void
Impl
<
op
::
Asin
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -53,12 +59,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Asin>::operator()()
...
@@ -53,12 +59,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Asin>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"asin(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"asin(I)"
})
.
finalize
());
.
finalize
());
}
}
// atan performs a simple elementwise arctan function.
// atan performs a simple elementwise arctan function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Atan
>::
operator
()()
void
Impl
<
op
::
Atan
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -66,12 +72,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Atan>::operator()()
...
@@ -66,12 +72,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Atan>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"atan(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"atan(I)"
})
.
finalize
());
.
finalize
());
}
}
// cos performs a simple elementwise cos function.
// cos performs a simple elementwise cos function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Cos
>::
operator
()()
void
Impl
<
op
::
Cos
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -79,12 +85,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Cos>::operator()()
...
@@ -79,12 +85,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Cos>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"cos(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"cos(I)"
})
.
finalize
());
.
finalize
());
}
}
// cosh performs a simple elementwise hyperbolic cos function.
// cosh performs a simple elementwise hyperbolic cos function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Cosh
>::
operator
()()
void
Impl
<
op
::
Cosh
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -92,12 +98,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Cosh>::operator()()
...
@@ -92,12 +98,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Cosh>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"cosh(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"cosh(I)"
})
.
finalize
());
.
finalize
());
}
}
// exp performs a simple elementwise natural exponential function.
// exp performs a simple elementwise natural exponential function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Exp
>::
operator
()()
void
Impl
<
op
::
Exp
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -105,12 +111,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Exp>::operator()()
...
@@ -105,12 +111,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Exp>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"exp(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"exp(I)"
})
.
finalize
());
.
finalize
());
}
}
// log performs a simple elementwise natural logarithm function.
// log performs a simple elementwise natural logarithm function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Log
>::
operator
()()
void
Impl
<
op
::
Log
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -118,12 +124,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Log>::operator()()
...
@@ -118,12 +124,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Log>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"log(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"log(I)"
})
.
finalize
());
.
finalize
());
}
}
// power performs a simple elementwise power function.
// power performs a simple elementwise power function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Power
>::
operator
()()
void
Impl
<
op
::
Power
>::
operator
()()
{
{
check_inputs
(
2
);
check_inputs
(
2
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -132,12 +138,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Power>::operator()()
...
@@ -132,12 +138,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Power>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"pow(I, E)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"pow(I, E)"
})
.
finalize
());
.
finalize
());
}
}
// sin performs a simple elementwise sin function.
// sin performs a simple elementwise sin function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sin
>::
operator
()()
void
Impl
<
op
::
Sin
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -145,12 +151,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sin>::operator()()
...
@@ -145,12 +151,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sin>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"sin(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"sin(I)"
})
.
finalize
());
.
finalize
());
}
}
// sinh performs a simple elementwise hyperbolic sin function.
// sinh performs a simple elementwise hyperbolic sin function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sinh
>::
operator
()()
void
Impl
<
op
::
Sinh
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -158,12 +164,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sinh>::operator()()
...
@@ -158,12 +164,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sinh>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"sinh(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"sinh(I)"
})
.
finalize
());
.
finalize
());
}
}
// sqrt performs a simple elementwise square root function.
// sqrt performs a simple elementwise square root function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sqrt
>::
operator
()()
void
Impl
<
op
::
Sqrt
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -171,12 +177,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sqrt>::operator()()
...
@@ -171,12 +177,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sqrt>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"sqrt(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"sqrt(I)"
})
.
finalize
());
.
finalize
());
}
}
// tan performs a simple elementwise tangent function.
// tan performs a simple elementwise tangent function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Tan
>::
operator
()()
void
Impl
<
op
::
Tan
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -184,12 +190,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Tan>::operator()()
...
@@ -184,12 +190,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Tan>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"tan(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"tan(I)"
})
.
finalize
());
.
finalize
());
}
}
// tanh performs a simple elementwise hyperbolic tangent function.
// tanh performs a simple elementwise hyperbolic tangent function.
template
<>
template
<>
void
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Tanh
>::
operator
()()
void
Impl
<
op
::
Tanh
>::
operator
()()
{
{
check_inputs
(
1
);
check_inputs
(
1
);
check_outputs
(
1
);
check_outputs
(
1
);
set_output
(
start_tile_function
()
set_output
(
start_tile_function
()
...
@@ -197,21 +203,24 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Tanh>::operator()()
...
@@ -197,21 +203,24 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Tanh>::operator()()
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Output
{
"O"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"tanh(I)"
})
.
add
(
builder
::
Elementwise
{
"O"
,
"tanh(I)"
})
.
finalize
());
.
finalize
());
}
}
namespace
namespace
{
{
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Acos
>::
Registration
register_acos
;
Impl
<
op
::
Acos
>::
Registration
register_acos
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Asin
>::
Registration
register_asin
;
Impl
<
op
::
Asin
>::
Registration
register_asin
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Atan
>::
Registration
register_atan
;
Impl
<
op
::
Atan
>::
Registration
register_atan
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Cos
>::
Registration
register_cos
;
Impl
<
op
::
Cos
>::
Registration
register_cos
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Cosh
>::
Registration
register_cosh
;
Impl
<
op
::
Cosh
>::
Registration
register_cosh
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Exp
>::
Registration
register_exp
;
Impl
<
op
::
Exp
>::
Registration
register_exp
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Log
>::
Registration
register_log
;
Impl
<
op
::
Log
>::
Registration
register_log
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Power
>::
Registration
register_power
;
Impl
<
op
::
Power
>::
Registration
register_power
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sin
>::
Registration
register_sin
;
Impl
<
op
::
Sin
>::
Registration
register_sin
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sinh
>::
Registration
register_sinh
;
Impl
<
op
::
Sinh
>::
Registration
register_sinh
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Sqrt
>::
Registration
register_sqrt
;
Impl
<
op
::
Sqrt
>::
Registration
register_sqrt
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Tan
>::
Registration
register_tan
;
Impl
<
op
::
Tan
>::
Registration
register_tan
;
ngraph
::
runtime
::
plaidml
::
Impl
<
ngraph
::
op
::
Tanh
>::
Registration
register_tanh
;
Impl
<
op
::
Tanh
>::
Registration
register_tanh
;
}
}
}
}
}
src/ngraph/runtime/plaidml/unit_test.manifest
View file @
61df6725
...
@@ -38,6 +38,8 @@ topk_2d_max_one # No plans to implement TopK
...
@@ -38,6 +38,8 @@ topk_2d_max_one # No plans to implement TopK
topk_2d_min_all # No plans to implement TopK
topk_2d_min_all # No plans to implement TopK
topk_2d_min_partial # No plans to implement TopK
topk_2d_min_partial # No plans to implement TopK
topk_2d_min_one # No plans to implement TopK
topk_2d_min_one # No plans to implement TopK
topk_int64 # No plans to implement TopK
topk_5d_max_partial # No plans to implement TopK
# Tests that PlaidML might be able to run at some point.
# Tests that PlaidML might be able to run at some point.
backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3
backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3
...
@@ -84,3 +86,5 @@ sum_3d_eliminate_zero_dim # Empty dims apparently should produce shape
...
@@ -84,3 +86,5 @@ sum_3d_eliminate_zero_dim # Empty dims apparently should produce shape
dot_0_0 # Empty dims apparently should produce shaped 0s
dot_0_0 # Empty dims apparently should produce shaped 0s
dot_matrix_2x0_0x2 # Empty dims apparently should produce shaped 0s
dot_matrix_2x0_0x2 # Empty dims apparently should produce shaped 0s
dot_2x0_0 # Empty dims apparently should produce shaped 0s
dot_2x0_0 # Empty dims apparently should produce shaped 0s
numeric_float_nan
numeric_double_nan
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment