Commit 1f76a2a7 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

Merge branch 'master' into cpu_layout2

parents b2fdb1f8 ea29c6e3
# API Changes
## Changes to convolution and pooling ops
* Backprop ops have been added for convolution ops.
* The convolution and pooling ops have had several methods/fields renamed, to reflect a shift
in terminology from "images" to "data". Generally this just means that you will have to
`s/image_batch/data_batch/` and `s/image_dilation_strides/data_dilation_strides/`.
* The following functions have been removed:
+ `AvgPool`: `get_channel_count`, `get_input_image_physical_shape`, `get_input_image_virtual_shape`, `get_output_image_shape`, `get_batch_size`, `get_image_dimension_count`
+ `MaxPool`: `get_channel_count`, `get_input_image_shape`, `get_output_image_shape`, `get_batch_size`, `get_image_dimension_count`
+ `Convolution`: `get_input_channel_count`, `get_output_channel_count`, `get_input_image_physical_shape`, `get_input_image_virtual_shape`, `get_output_image_shape`, `get_window_physical_shape`, `get_window_virtual_shape`, `get_batch_size`, `get_image_dimension_count`
All of the above information can be inferred from the shapes and parameters of the op.
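For example, each of the removed values can be read straight off the data batch shape `(N, C, d_1, ..., d_n)`; a hypothetical helper (not part of the nGraph API, with `Shape` as a local stand-in for `ngraph::Shape`) might look like:

```cpp
#include <cstddef>
#include <vector>

using Shape = std::vector<size_t>; // stand-in for ngraph::Shape

size_t batch_size(const Shape& data_shape) { return data_shape[0]; }                 // was get_batch_size()
size_t channel_count(const Shape& data_shape) { return data_shape[1]; }              // was get_channel_count()
size_t spatial_dim_count(const Shape& data_shape) { return data_shape.size() - 2; }  // was get_image_dimension_count()
```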
## Negative convolution padding
`Convolution` now allows negative padding. This means that the `padding_below` and `padding_above`
......
@@ -170,6 +170,8 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
    runtime/cpu/cpu_external_function.cpp
    runtime/cpu/cpu_tensor_view_wrapper.cpp
    runtime/cpu/cpu_layout_descriptor.cpp
    runtime/cpu/ops/matmul_bias.cpp
    runtime/cpu/pass/cpu_fusion.cpp
    runtime/cpu/pass/cpu_layout.cpp
)
# LLVM binary builds are typically built without RTTI
......
@@ -152,3 +152,19 @@ std::list<shared_ptr<Node>> Function::get_ops() const
});
return ops;
}
void Function::replace_output_op(std::shared_ptr<Node> old, std::shared_ptr<Node> repl)
{
auto it = std::find(begin(m_results), end(m_results), old);
if (it != end(m_results))
{
NGRAPH_DEBUG << "Replacing output " << old->get_name() << " w/ " << repl->get_name();
*it = repl;
}
}
void Function::replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl)
{
replace_output_op(old, repl);
ngraph::replace_node(old, repl, true);
}
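A hedged usage sketch of the new helpers; the `op::Parameter`/`Function` constructor signatures and the `op::Parameters` typedef are assumed from the surrounding codebase, not shown in this diff:

```cpp
// Replace the op that produces a function's result while keeping m_results consistent.
auto a = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto b = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto sum = a + b;                                  // current result op
auto f = std::make_shared<Function>(sum, op::Parameters{a, b});

auto prod = std::make_shared<op::Multiply>(a, b);  // replacement with the same output shape
f->replace_node(sum, prod);                        // rewires users and updates m_results
```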
@@ -78,6 +78,10 @@ namespace ngraph
size_t get_instance_id() { return m_instance_id; }
size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t);
//updates old w/ repl in m_results list
void replace_output_op(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
//updates graph and m_results list
void replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
protected:
Nodes m_results;
......
@@ -105,12 +105,15 @@ void ngraph::free_nodes(shared_ptr<Function> p)
    }
}

-void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
+void ngraph::replace_node(std::shared_ptr<Node> target,
+                          std::shared_ptr<Node> replacement,
+                          bool replace_output)
{
-    if (target->is_output()) //this restriction can be lifted when we find an use case for it
+    if (target->is_output() && !replace_output)
    {
        return;
    }

    //fix input/output descriptors
    assert(target->get_outputs().size() == replacement->get_outputs().size());
    for (size_t i = 0; i < target->get_outputs().size(); i++)
......
@@ -42,7 +42,9 @@ namespace ngraph
    void free_nodes(std::shared_ptr<Function>);

-   void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
+   void replace_node(std::shared_ptr<Node> target,
+                     std::shared_ptr<Node> replacement,
+                     bool replace_output = false);

    void replace_node_users_arguments(std::shared_ptr<Node> target,
                                      std::shared_ptr<Node> replacement);
......
@@ -110,12 +110,12 @@ namespace nervana
                        __PRETTY_FUNCTION__)                                       \
        .stream()

// #define NGRAPH_DEBUG                                                            \
//     nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_DEBUG,                     \
//                         nervana::get_file_name(__FILE__),                       \
//                         __LINE__,                                               \
//                         __PRETTY_FUNCTION__)                                    \
//         .stream()
#define NGRAPH_DEBUG nervana::get_nil_stream()
}
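A small usage note: with the nil-stream definition above, `NGRAPH_DEBUG` still accepts stream insertions (the operands are evaluated, but the text is discarded); re-enabling the commented-out definition routes the same statement through the debug logger. For example:

```cpp
NGRAPH_DEBUG << "Replacing output " << old->get_name() << " w/ " << repl->get_name();
```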
...@@ -37,109 +37,115 @@ op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg, ...@@ -37,109 +37,115 @@ op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg,
if (arg_shape.size() < 3) if (arg_shape.size() < 3)
{ {
throw ngraph_error( throw ngraph_error(
"Average-pool image batch input must have rank of at least 3 (one batch axis, one " "Average-pool data batch input must have rank of at least 3 (one batch axis, one "
"channel axis, at least one image dimension)."); "channel axis, at least one spatial dimension).");
} }
m_batch_size = arg_shape[0]; size_t batch_size = arg_shape[0];
if (m_batch_size == 0) if (batch_size == 0)
{ {
throw ngraph_error("Average-pool image batch size is zero."); throw ngraph_error("Average-pool data batch size is zero.");
} }
m_channel_count = arg_shape[1]; size_t channel_count = arg_shape[1];
if (m_channel_count == 0) if (channel_count == 0)
{ {
throw ngraph_error("Average-pool requires at least one image depth channel."); throw ngraph_error("Average-pool requires at least one feature channel.");
} }
m_image_dimension_count = arg_shape.size() - 2; size_t spatial_dimension_count = arg_shape.size() - 2;
// //
// Make sure window shape, window movement strides, and have same rank as Di. // Make sure window shape, window movement strides, and have same rank as Di.
// //
if (m_window_shape.size() != m_image_dimension_count) if (window_shape.size() != spatial_dimension_count)
{ {
throw ngraph_error( throw ngraph_error(
"Average-pool window shape rank does not match number of image dimensions."); "Average-pool window shape rank does not match number of spatial dimensions.");
} }
if (m_window_movement_strides.size() != m_image_dimension_count) if (window_movement_strides.size() != spatial_dimension_count)
{ {
throw ngraph_error( throw ngraph_error(
"Average-pool window movement stride rank does not match number of image dimensions."); "Average-pool window movement stride rank does not match number of spatial "
"dimensions.");
} }
if (m_padding_below.size() != m_image_dimension_count) if (padding_below.size() != spatial_dimension_count)
{ {
throw ngraph_error( throw ngraph_error(
"Average-pool below-padding rank does not match number of image dimensions."); "Average-pool below-padding rank does not match number of spatial dimensions.");
} }
if (m_padding_above.size() != m_image_dimension_count) if (padding_above.size() != spatial_dimension_count)
{ {
throw ngraph_error( throw ngraph_error(
"Average-pool above-padding rank does not match number of image dimensions."); "Average-pool above-padding rank does not match number of spatial dimensions.");
} }
// //
// Extract input image shape Di and make sure all dimensions are larger than 0. // Extract input item shape Di and make sure all dimensions are larger than 0.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) Shape input_item_virtual_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
size_t dim_size = arg_shape[1 + 1 + i]; size_t dim_size = arg_shape[1 + 1 + i];
m_input_image_physical_shape.push_back(dim_size); size_t virtual_dim_size = padding_below[i] + dim_size + padding_above[i];
m_input_image_virtual_shape.push_back(padding_below[i] + dim_size + padding_above[i]); input_item_virtual_shape.push_back(virtual_dim_size);
if (m_input_image_virtual_shape[i] == 0) if (virtual_dim_size == 0)
{ {
throw ngraph_error("Average-pool input image dimension is zero even after padding."); throw ngraph_error("Average-pool input spatial dimension is zero even after padding.");
} }
} }
// //
// Make sure window shape dimensions are all larger than 0. // Make sure window shape dimensions are all larger than 0.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
if (m_window_shape[i] == 0) if (window_shape[i] == 0)
{ {
throw ngraph_error("Average-pool window shape has a zero-length axis."); throw ngraph_error("Average-pool window shape has a zero-length axis.");
} }
} }
// //
// Make the max pooling window fits within the image dimensions. // Make the max pooling window fits within the spatial dimensions.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
if (m_window_shape[i] > m_input_image_virtual_shape[i]) if (window_shape[i] > input_item_virtual_shape[i])
{ {
throw ngraph_error( throw ngraph_error(
"Average-pool window shape is larger than the image even after padding."); "Average-pool window shape is larger than the spatial dimensions even after "
"padding.");
} }
} }
// //
// Compute image output shape Do, checking at the same time that all window movement strides are larger than 0. // Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) Shape output_item_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
if (m_window_movement_strides[i] == 0) if (window_movement_strides[i] == 0)
{ {
throw ngraph_error("Average-pool window axis movement stride is zero."); throw ngraph_error("Average-pool window axis movement stride is zero.");
} }
m_output_image_shape.push_back(ceil_div( output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - window_shape[i] + 1,
m_input_image_virtual_shape[i] - m_window_shape[i] + 1, m_window_movement_strides[i])); window_movement_strides[i]));
} }
// //
// Construct result shape: NCDo. // Construct result shape: NCDo.
// //
Shape result_shape(1 + 1 + m_image_dimension_count); Shape result_shape(1 + 1 + spatial_dimension_count);
result_shape[0] = m_batch_size; result_shape[0] = batch_size;
result_shape[1] = m_channel_count; result_shape[1] = channel_count;
std::copy(m_output_image_shape.begin(), m_output_image_shape.end(), result_shape.begin() + 2); std::copy(output_item_shape.begin(), output_item_shape.end(), result_shape.begin() + 2);
set_value_type_checked(get_input_element_type(0), result_shape); set_value_type_checked(get_input_element_type(0), result_shape);
} }
...@@ -148,7 +154,7 @@ static Shape default_padding(const std::shared_ptr<Node>& arg) ...@@ -148,7 +154,7 @@ static Shape default_padding(const std::shared_ptr<Node>& arg)
{ {
if (arg->get_outputs().size() != 1) if (arg->get_outputs().size() != 1)
{ {
throw ngraph_error("Average-pool image batch argument must have exactly one output"); throw ngraph_error("Average-pool data batch argument must have exactly one output");
} }
auto& arg_shape = arg->get_outputs().at(0).get_shape(); auto& arg_shape = arg->get_outputs().at(0).get_shape();
...@@ -156,8 +162,8 @@ static Shape default_padding(const std::shared_ptr<Node>& arg) ...@@ -156,8 +162,8 @@ static Shape default_padding(const std::shared_ptr<Node>& arg)
{ {
// For consistency we should throw the same error message here that we throw in the constructor. // For consistency we should throw the same error message here that we throw in the constructor.
throw ngraph_error( throw ngraph_error(
"Average-pool image batch input must have rank of at least 3 (one batch axis, one " "Average-pool data batch input must have rank of at least 3 (one batch axis, one "
"channel axis, at least one image dimension)."); "channel axis, at least one spatial dimension).");
} }
return Shape(arg_shape.size() - 2, 0); return Shape(arg_shape.size() - 2, 0);
} }
...@@ -174,7 +180,7 @@ static Strides default_strides(const std::shared_ptr<Node>& arg) ...@@ -174,7 +180,7 @@ static Strides default_strides(const std::shared_ptr<Node>& arg)
{ {
if (arg->get_outputs().size() != 1) if (arg->get_outputs().size() != 1)
{ {
throw ngraph_error("Average-pool image batch argument must have exactly one output"); throw ngraph_error("Average-pool data batch argument must have exactly one output");
} }
auto& arg_shape = arg->get_outputs().at(0).get_shape(); auto& arg_shape = arg->get_outputs().at(0).get_shape();
...@@ -182,8 +188,8 @@ static Strides default_strides(const std::shared_ptr<Node>& arg) ...@@ -182,8 +188,8 @@ static Strides default_strides(const std::shared_ptr<Node>& arg)
{ {
// For consistency we should throw the same error message here that we throw in the constructor. // For consistency we should throw the same error message here that we throw in the constructor.
throw ngraph_error( throw ngraph_error(
"Average-pool image batch input must have rank of at least 3 (one batch axis, one " "Average-pool data batch input must have rank of at least 3 (one batch axis, one "
"channel axis, at least one image dimension)."); "channel axis, at least one spatial dimension).");
} }
return Strides(arg_shape.size() - 2, 1); return Strides(arg_shape.size() - 2, 1);
} }
...@@ -203,13 +209,6 @@ bool op::AvgPool::is_functionally_identical(const Node& other) const ...@@ -203,13 +209,6 @@ bool op::AvgPool::is_functionally_identical(const Node& other) const
rc &= m_window_movement_strides == rhs.m_window_movement_strides; rc &= m_window_movement_strides == rhs.m_window_movement_strides;
rc &= m_padding_below == rhs.m_padding_below; rc &= m_padding_below == rhs.m_padding_below;
rc &= m_padding_above == rhs.m_padding_above; rc &= m_padding_above == rhs.m_padding_above;
rc &= m_window_movement_strides == rhs.m_window_movement_strides;
rc &= m_channel_count == rhs.m_channel_count;
rc &= m_input_image_physical_shape == rhs.m_input_image_physical_shape;
rc &= m_input_image_virtual_shape == rhs.m_input_image_virtual_shape;
rc &= m_output_image_shape == rhs.m_output_image_shape;
rc &= m_batch_size == rhs.m_batch_size;
rc &= m_image_dimension_count == rhs.m_image_dimension_count;
} }
else else
{ {
......
@@ -22,8 +22,9 @@ namespace ngraph
    {
        /// \brief Batched average pooling operation, with optional padding and window stride.
        ///
-       /// Average pooling takes as its input an image batch tensor of shape \f$(N,C,d_1,\dots,d_n)\f$ where \f$n > 0\f$, every \f$d_i > 0\f$, and where \f$N\f$ is
-       /// the batch size, and \f$C > 0\f$ is the number of channels (sometimes called features). It also takes four parameters:
+       /// Average pooling takes as its input a data batch tensor of shape \f$(N,C,d_1,\dots,d_n)\f$ where \f$n > 0\f$, every \f$d_i > 0\f$, and where \f$N\f$ is
+       /// the batch size, and \f$C > 0\f$ is the number of channels (sometimes called features). The dimensions \f$(d_1,\dots,d_n)\f$ correspond to the shape of
+       /// an \f$n\f$-dimensional data item in a batch. For example, where \f$n=2\f$, the data may represent a two-dimensional image. It also takes four parameters:
        ///
        /// 1. <i>(the window shape)</i> a size vector \f$(w_1,\dots,w_n)\f$ where every \f$w_i \le d_i\f$; and
        /// 2. <i>(the window movement strides, optional)</i> a vector of positive integers \f$(s_1,\dots,s_n)\f$.
@@ -32,7 +33,7 @@ namespace ngraph
        ///
        /// The output has the shape \f$(N,C,d'_1,\dots,d'_n)\f$, where \f$d'_i = \lceil \frac{p_i + d_i + q_i - w_i + 1}{s_i} \rceil\f$.
        ///
-       /// *In the absence of padding*, given an input image batch tensor \f$T_\textit{in}\f$, the output tensor is defined by the equation
+       /// *In the absence of padding*, given an input data batch tensor \f$T_\textit{in}\f$, the output tensor is defined by the equation
        ///
        /// \f[
        ///        T_\textit{out}[a,c,i_1,\dots,i_n] = \frac{\sum_{j_1 = s_1 i_1, \dots, j_n = s_n i_n}^{j_1 = s_1 i_1 + w_1 - 1, \dots, j_n = s_n i_n + w_n - 1} T_\textit{in}[a,c,j_1,\dots,j_n]}{\prod_{i=1}^n{w_i}}
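The output-shape rule above is small enough to check in isolation. A minimal sketch (not the op's implementation; `Shape` and `ceil_div` are local stand-ins) of d'_i = ceil((p_i + d_i + q_i - w_i + 1) / s_i):

```cpp
#include <cstddef>
#include <vector>

using Shape = std::vector<size_t>; // stand-in for ngraph::Shape

inline size_t ceil_div(size_t x, size_t y) { return (x + y - 1) / y; }

// data = (N, C, d_1, ..., d_n); window, strides, pad_below, pad_above each have length n.
Shape avg_pool_output_shape(const Shape& data,
                            const Shape& window,
                            const Shape& strides,
                            const Shape& pad_below,
                            const Shape& pad_above)
{
    Shape out{data[0], data[1]}; // N and C pass through unchanged
    for (size_t i = 0; i < data.size() - 2; ++i)
    {
        size_t padded = pad_below[i] + data[2 + i] + pad_above[i];
        out.push_back(ceil_div(padded - window[i] + 1, strides[i]));
    }
    return out;
}
```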
...@@ -65,7 +66,7 @@ namespace ngraph ...@@ -65,7 +66,7 @@ namespace ngraph
public: public:
/// \brief Constructs a batched average pooling operation. /// \brief Constructs a batched average pooling operation.
/// ///
/// \param arg The node producing the input image batch tensor. /// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape. /// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides. /// \param window_movement_strides The window movement strides.
/// \param padding_below The below-padding shape. /// \param padding_below The below-padding shape.
...@@ -78,7 +79,7 @@ namespace ngraph ...@@ -78,7 +79,7 @@ namespace ngraph
/// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0). /// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0).
/// ///
/// \param arg The node producing the input image batch tensor. /// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape. /// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides. /// \param window_movement_strides The window movement strides.
AvgPool(const std::shared_ptr<Node>& arg, AvgPool(const std::shared_ptr<Node>& arg,
...@@ -87,7 +88,7 @@ namespace ngraph ...@@ -87,7 +88,7 @@ namespace ngraph
/// \brief Constructs an unstrided batched convolution operation (i.e., all window movement strides are 1 and all padding shapes are set to 0). /// \brief Constructs an unstrided batched convolution operation (i.e., all window movement strides are 1 and all padding shapes are set to 0).
/// ///
/// \param arg The node producing the input image batch tensor. /// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape. /// \param window_shape The window shape.
AvgPool(const std::shared_ptr<Node>& arg, const Shape& window_shape); AvgPool(const std::shared_ptr<Node>& arg, const Shape& window_shape);
...@@ -102,6 +103,7 @@ namespace ngraph ...@@ -102,6 +103,7 @@ namespace ngraph
m_padding_below, m_padding_below,
m_padding_above); m_padding_above);
} }
bool is_functionally_identical(const Node&) const override;
/// \return The window shape. /// \return The window shape.
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
...@@ -111,38 +113,11 @@ namespace ngraph ...@@ -111,38 +113,11 @@ namespace ngraph
const Shape& get_padding_below() const { return m_padding_below; } const Shape& get_padding_below() const { return m_padding_below; }
/// \return The above-padding shape. /// \return The above-padding shape.
const Shape& get_padding_above() const { return m_padding_above; } const Shape& get_padding_above() const { return m_padding_above; }
/// \return The number of image channels.
size_t get_channel_count() const { return m_channel_count; }
/// \return The input image physical shape, not including padding.
const Shape& get_input_image_physical_shape() const
{
return m_input_image_physical_shape;
}
/// \return The input image virtual shape, including padding.
const Shape& get_input_image_virtual_shape() const
{
return m_input_image_virtual_shape;
}
/// \return The output image shape.
const Shape& get_output_image_shape() const { return m_output_image_shape; }
/// \return The batch size.
size_t get_batch_size() const { return m_batch_size; }
/// \return The number of image dimensions.
size_t get_image_dimension_count() const { return m_image_dimension_count; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
Shape m_window_shape; Shape m_window_shape;
Strides m_window_movement_strides; Strides m_window_movement_strides;
Shape m_padding_below; Shape m_padding_below;
Shape m_padding_above; Shape m_padding_above;
size_t m_channel_count;
Shape m_input_image_physical_shape;
Shape m_input_image_virtual_shape;
Shape m_output_image_shape;
size_t m_batch_size;
size_t m_image_dimension_count;
}; };
} }
} }
...@@ -38,93 +38,98 @@ op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg, ...@@ -38,93 +38,98 @@ op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg,
if (arg_shape.size() < 3) if (arg_shape.size() < 3)
{ {
throw ngraph_error( throw ngraph_error(
"Max pool image batch input must have rank of at least 3 (one batch axis, one " "Max pool data batch input must have rank of at least 3 (one batch axis, one "
"channel axis, at least one image dimension)."); "channel axis, at least one spatial dimension).");
} }
m_batch_size = arg_shape[0]; size_t batch_size = arg_shape[0];
if (m_batch_size == 0) if (batch_size == 0)
{ {
throw ngraph_error("Max pool image batch size is zero."); throw ngraph_error("Max pool data batch size is zero.");
} }
m_channel_count = arg_shape[1]; size_t channel_count = arg_shape[1];
if (m_channel_count == 0) if (channel_count == 0)
{ {
throw ngraph_error("Max pool requires at least one image depth channel."); throw ngraph_error("Max pool requires at least one feature channel.");
} }
m_image_dimension_count = arg_shape.size() - 2; size_t spatial_dimension_count = arg_shape.size() - 2;
// //
// Make sure window shape and movement strides have same rank as Di. // Make sure window shape and movement strides have same rank as Di.
// //
if (m_window_shape.size() != m_image_dimension_count) if (window_shape.size() != spatial_dimension_count)
{ {
throw ngraph_error("Max pool window shape rank does not match number of image dimensions."); throw ngraph_error(
"Max pool window shape rank does not match number of spatial dimensions.");
} }
if (m_window_movement_strides.size() != m_image_dimension_count) if (window_movement_strides.size() != spatial_dimension_count)
{ {
throw ngraph_error( throw ngraph_error(
"Max pool window movement stride rank does not match number of image dimensions."); "Max pool window movement stride rank does not match number of spatial dimensions.");
} }
// //
// Extract input image shape Di and make sure all dimensions are larger than 0. // Extract input item shape Di and make sure all dimensions are larger than 0.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) Shape input_spatial_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
m_input_image_shape.push_back(arg_shape[1 + 1 + i]); input_spatial_shape.push_back(arg_shape[1 + 1 + i]);
if (m_input_image_shape[i] == 0) if (input_spatial_shape[i] == 0)
{ {
throw ngraph_error("Max pool input image dimension is zero."); throw ngraph_error("Max pool input spatial dimension is zero.");
} }
} }
// //
// Make sure window shape dimensions are all larger than 0. // Make sure window shape dimensions are all larger than 0.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
if (m_window_shape[i] == 0) if (window_shape[i] == 0)
{ {
throw ngraph_error("Max pool window shape has a zero-length axis."); throw ngraph_error("Max pool window shape has a zero-length axis.");
} }
} }
// //
// Make the max pooling window fits within the image dimensions. // Make the max pooling window fits within the spatial dimensions.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
if (m_window_shape[i] > m_input_image_shape[i]) if (window_shape[i] > input_spatial_shape[i])
{ {
throw ngraph_error("Max pool window shape is larger than the image."); throw ngraph_error("Max pool window shape is larger than the spatial dimensions.");
} }
} }
// //
// Compute image output shape Do, checking at the same time that all window movement strides are larger than 0. // Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) Shape output_spatial_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{ {
if (m_window_movement_strides[i] == 0) if (window_movement_strides[i] == 0)
{ {
throw ngraph_error("Max pool window axis movement stride is zero."); throw ngraph_error("Max pool window axis movement stride is zero.");
} }
m_output_image_shape.push_back( output_spatial_shape.push_back(
ceil_div(m_input_image_shape[i] - m_window_shape[i] + 1, m_window_movement_strides[i])); ceil_div(input_spatial_shape[i] - window_shape[i] + 1, window_movement_strides[i]));
} }
// //
// Construct result shape: NCDo. // Construct result shape: NCDo.
// //
Shape result_shape(1 + 1 + m_image_dimension_count); Shape result_shape(1 + 1 + spatial_dimension_count);
result_shape[0] = m_batch_size; result_shape[0] = batch_size;
result_shape[1] = m_channel_count; result_shape[1] = channel_count;
std::copy(m_output_image_shape.begin(), m_output_image_shape.end(), result_shape.begin() + 2); std::copy(output_spatial_shape.begin(), output_spatial_shape.end(), result_shape.begin() + 2);
set_value_type_checked(get_inputs().at(0).get_element_type(), result_shape); set_value_type_checked(get_inputs().at(0).get_element_type(), result_shape);
} }
...@@ -133,7 +138,7 @@ static Strides default_strides(const std::shared_ptr<Node>& arg) ...@@ -133,7 +138,7 @@ static Strides default_strides(const std::shared_ptr<Node>& arg)
{ {
if (arg->get_outputs().size() != 1) if (arg->get_outputs().size() != 1)
{ {
throw ngraph_error("Max pool image batch argument must have exactly one output"); throw ngraph_error("Max pool data batch argument must have exactly one output");
} }
auto& arg_shape = arg->get_outputs().at(0).get_shape(); auto& arg_shape = arg->get_outputs().at(0).get_shape();
...@@ -141,8 +146,8 @@ static Strides default_strides(const std::shared_ptr<Node>& arg) ...@@ -141,8 +146,8 @@ static Strides default_strides(const std::shared_ptr<Node>& arg)
{ {
// For consistency we should throw the same error message here that we throw in the constructor. // For consistency we should throw the same error message here that we throw in the constructor.
throw ngraph_error( throw ngraph_error(
"Max pool image batch input must have rank of at least 3 (one batch axis, one " "Max pool data batch input must have rank of at least 3 (one batch axis, one "
"channel axis, at least one image dimension)."); "channel axis, at least one spatial dimension).");
} }
return Strides(arg_shape.size() - 2, 1); return Strides(arg_shape.size() - 2, 1);
} }
...@@ -160,11 +165,6 @@ bool op::MaxPool::is_functionally_identical(const Node& other) const ...@@ -160,11 +165,6 @@ bool op::MaxPool::is_functionally_identical(const Node& other) const
const MaxPool& rhs = dynamic_cast<const MaxPool&>(other); const MaxPool& rhs = dynamic_cast<const MaxPool&>(other);
rc &= m_window_shape == rhs.m_window_shape; rc &= m_window_shape == rhs.m_window_shape;
rc &= m_window_movement_strides == rhs.m_window_movement_strides; rc &= m_window_movement_strides == rhs.m_window_movement_strides;
rc &= m_channel_count == rhs.m_channel_count;
rc &= m_input_image_shape == rhs.m_input_image_shape;
rc &= m_output_image_shape == rhs.m_output_image_shape;
rc &= m_batch_size == rhs.m_batch_size;
rc &= m_image_dimension_count == rhs.m_image_dimension_count;
} }
else else
{ {
......
@@ -22,15 +22,16 @@ namespace ngraph
    {
        /// \brief Batched max pooling operation, with optional window stride.
        ///
-       /// Max pooling takes as its input an image batch tensor of shape \f$(N,C,d_1,\dots,d_n)\f$ where \f$n > 0\f$, every \f$d_i > 0\f$, and where \f$N\f$ is
-       /// the batch size, and \f$C > 0\f$ is the number of channels (sometimes called features). It also takes two parameters:
+       /// Max pooling takes as its input a data batch tensor of shape \f$(N,C,d_1,\dots,d_n)\f$ where \f$n > 0\f$, every \f$d_i > 0\f$, and where \f$N\f$ is
+       /// the batch size, and \f$C > 0\f$ is the number of channels (sometimes called features). The dimensions \f$(d_1,\dots,d_n)\f$ correspond to the shape of
+       /// an \f$n\f$-dimensional data item in a batch. For example, where \f$n=2\f$, the data may represent a two-dimensional image. It also takes two parameters:
        ///
        /// 1. <i>(the window shape)</i> a size vector \f$(w_1,\dots,w_n)\f$ where every \f$w_i \le d_i\f$; and
        /// 2. <i>(the window movement strides, optional)</i> a vector of positive integers \f$(s_1,\dots,s_n)\f$.
        ///
        /// The output has the shape \f$(N,C,d'_1,\dots,d'_n)\f$, where \f$d'_i = \lceil \frac{d_i - w_i + 1}{s_i} \rceil\f$.
        ///
-       /// Given an input image batch tensor \f$T_\textit{in}\f$, the output tensor is defined by the equation
+       /// Given an input data batch tensor \f$T_\textit{in}\f$, the output tensor is defined by the equation
        ///
        /// \f[
        ///        T_\textit{out}[a,c,i_1,\dots,i_n] = \max_{j_1 = s_1 i_1, \dots, j_n = s_n i_n}^{j_1 = s_1 i_1 + w_1 - 1, \dots, j_n = s_n i_n + w_n - 1} (T_\textit{in}[a,c,j_1,\dots,j_n])
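A reference sketch of the defining equation for the simplest case, n = 1 and a single (batch, channel) slice with no padding; this is illustrative only, not the backend kernel:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// One spatial axis: out[i] = max(in[s*i], ..., in[s*i + w - 1]),
// producing ceil((d - w + 1) / s) output positions.
std::vector<float> max_pool_1d(const std::vector<float>& in, size_t w, size_t s)
{
    std::vector<float> out;
    for (size_t start = 0; start + w <= in.size(); start += s)
    {
        out.push_back(*std::max_element(in.begin() + start, in.begin() + start + w));
    }
    return out;
}
```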
...@@ -41,7 +42,7 @@ namespace ngraph ...@@ -41,7 +42,7 @@ namespace ngraph
public: public:
/// \brief Constructs a batched max pooling operation. /// \brief Constructs a batched max pooling operation.
/// ///
/// \param arg The node producing the input image batch tensor. /// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape. /// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides. /// \param window_movement_strides The window movement strides.
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const std::shared_ptr<Node>& arg,
...@@ -50,7 +51,7 @@ namespace ngraph ...@@ -50,7 +51,7 @@ namespace ngraph
/// \brief Constructs an unstrided batched convolution operation (i.e., all window movement strides are 1). /// \brief Constructs an unstrided batched convolution operation (i.e., all window movement strides are 1).
/// ///
/// \param arg The node producing the input image batch tensor. /// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape. /// \param window_shape The window shape.
MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape); MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape);
...@@ -62,35 +63,18 @@ namespace ngraph ...@@ -62,35 +63,18 @@ namespace ngraph
return std::make_shared<MaxPool>( return std::make_shared<MaxPool>(
new_args.at(0), m_window_shape, m_window_movement_strides); new_args.at(0), m_window_shape, m_window_movement_strides);
} }
bool is_functionally_identical(const Node&) const override;
/// \return The window shape. /// \return The window shape.
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
/// \return The window movement strides. /// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
/// \return The number of image channels.
size_t get_channel_count() const { return m_channel_count; }
/// \return The input image shape.
const Shape& get_input_image_shape() const { return m_input_image_shape; }
/// \return The output image shape.
const Shape& get_output_image_shape() const { return m_output_image_shape; }
/// \return The batch size.
size_t get_batch_size() const { return m_batch_size; }
/// \return The number of image dimensions.
size_t get_image_dimension_count() const { return m_image_dimension_count; }
bool is_functionally_identical(const Node&) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
Shape m_window_shape; Shape m_window_shape;
Strides m_window_movement_strides; Strides m_window_movement_strides;
size_t m_channel_count;
Shape m_input_image_shape;
Shape m_output_image_shape;
size_t m_batch_size;
size_t m_image_dimension_count;
}; };
} }
} }
#include "graph_rewrite.hpp" // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include "graph_rewrite.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list( bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
const std::list<std::shared_ptr<ngraph::Node>>& nodes, const std::list<std::shared_ptr<ngraph::Node>>& nodes,
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers) const std::vector<std::shared_ptr<pattern::Matcher>>& matchers,
std::shared_ptr<ngraph::Function> f)
{ {
bool rewritten = false; bool rewritten = false;
for (auto node : nodes) for (auto node : nodes)
@@ -15,23 +31,26 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
        for (auto matcher : matchers)
        {
            NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node << " , "
-                        << node->get_name();
+                        << node->get_name() << " , is_output = " << node->is_output();
-           if (!node->is_output() /*this restriction can be lifted when we find an use case for it*/
-               &&
-               matcher->match(node))
+           if (matcher->match(node))
            {
                NGRAPH_DEBUG << "Matcher " << matcher << " matched " << node << " , "
                             << node->get_name();
                rewritten = true;
-               matcher->process_match();
-               break; //move onto the next node
+               auto result = matcher->process_match();
+               if (result)
+               {
+                   f->replace_node(node, result);
+                   //move onto the next node
+                   break;
+               }
            }
        }
    }
    return rewritten;
}
-bool ngraph::pass::GraphRewrite::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
+bool ngraph::pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
{
-   return run_matchers_on_nodes_list(nodes, m_matchers);
+   return run_matchers_on_nodes_list(f->get_ordered_ops(), m_matchers, f);
}
@@ -40,19 +40,21 @@ namespace ngraph
/// Patterns can be added by using \sa add_matcher
/// Callbacks should use \sa replace_node to transform matched sub graphs
-class ngraph::pass::GraphRewrite : public CallGraphPass
+class ngraph::pass::GraphRewrite : public FunctionPass
{
public:
    GraphRewrite()
-       : CallGraphPass()
+       : FunctionPass()
    {
    }

    void add_matcher(std::shared_ptr<pattern::Matcher> m) { m_matchers.push_back(m); }
-   virtual bool run_on_call_graph(const std::list<std::shared_ptr<ngraph::Node>>&) override;
    static bool
        run_matchers_on_nodes_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes,
-                                  const std::vector<std::shared_ptr<pattern::Matcher>>& matchers);
+                                  const std::vector<std::shared_ptr<pattern::Matcher>>& matchers,
+                                  std::shared_ptr<ngraph::Function> f);
+   virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);

private:
    //enable cascading rewrites
......
@@ -63,8 +63,10 @@ namespace ngraph
                auto args = get_arguments(label);
                if (args.size() > 0)
                {
-                   assert(args.size() ==
-                          1); //it should be impossible to construct labels w/ more than one arg
+                   if (args.size() != 1)
+                   {
+                       throw ngraph_error("Labels can only take 1 argument!");
+                   }
                    NGRAPH_DEBUG << "[MATCHER] Label describes a sub graph in the pattern";
                    is_match = match_node(args.at(0), graph_node, pattern_map);
                }
@@ -92,7 +94,11 @@ namespace ngraph
                else
                {
                    auto args = get_arguments(any);
-                   assert(args.size() == 1);
+                   if (args.size() != 1)
+                   {
+                       throw ngraph_error("Any can only take one argument");
+                   }
                    return match_node(args.at(0), graph_node, pattern_map);
                }
            }
@@ -101,7 +107,10 @@
                                 const std::shared_ptr<Node>& graph_node,
                                 PatternMap& pattern_map)
            {
-               assert(pattern_node && graph_node);
+               if (!pattern_node || !graph_node)
+               {
+                   throw ngraph_error("pattern_node or graph_node shouldn't be nullptrs!");
+               }
                NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : "
                             << "pattern = " << pattern_node->get_name() << " matched "
@@ -191,17 +200,24 @@ namespace ngraph
                return false;
            }

-           void Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
+           std::shared_ptr<Node> Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
            {
                gr_callback_fn cb = m_callback;
                if (callback)
                {
                    cb = callback;
                }

-               assert(cb);
-               assert(this->m_match_root);
-               cb(*this);
+               if (!cb)
+               {
+                   throw ngraph_error("process_match invoked w/o a callback function");
+               }
+
+               if (!this->m_match_root)
+               {
+                   throw ngraph_error("process_match invoked w/o a match");
+               }
+
+               return cb(*this);
            }
static Nodes get_users(std::shared_ptr<Node> node)
......
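With `process_match` now returning the callback's result, a caller outside of `GraphRewrite` would drive a matcher roughly like this (a sketch; `pattern`, `callback`, `candidate`, and the function `f` are assumed to exist):

```cpp
auto matcher = std::make_shared<pattern::Matcher>(pattern, callback);
if (matcher->match(candidate))
{
    auto replacement = matcher->process_match(); // runs the callback; may return nullptr
    if (replacement)
    {
        f->replace_node(candidate, replacement); // updates the graph and the result list
    }
}
```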
@@ -29,7 +29,7 @@ namespace ngraph
    namespace pattern
    {
-       using gr_callback_fn = std::function<void(class Matcher& m)>;
+       using gr_callback_fn = std::function<std::shared_ptr<Node>(class Matcher& m)>;

        namespace op
        {
@@ -60,7 +60,7 @@ namespace ngraph
            /// \param graph_node is an input graph to be matched against
            bool match(const std::shared_ptr<Node>& graph_node);

-           void process_match(gr_callback_fn callback = nullptr);
+           std::shared_ptr<Node> process_match(gr_callback_fn callback = nullptr);

            void reset() {}
            std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
......
@@ -60,6 +60,7 @@ namespace ngraph
                static void EMITTER_DECL(EmitSelect);
                static void EMITTER_DECL(EmitSubtract);
                static void EMITTER_DECL(EmitBroadcast);
                static void EMITTER_DECL(EmitMatmulBias);
                static void EMITTER_DECL(EmitConvert);
                static void EMITTER_DECL(EmitConstant);
                static void EMITTER_DECL(EmitReshape);
@@ -85,6 +86,8 @@ namespace ngraph
                static void EMITTER_DECL(EmitCeiling);
                static void EMITTER_DECL(EmitSqrt);
                static void EMITTER_DECL(EmitConvolution);
                static void EMITTER_DECL(EmitConvolutionBackpropFilters);
                static void EMITTER_DECL(EmitConvolutionBackpropData);
                static void EMITTER_DECL(EmitNot);
                static void EMITTER_DECL(EmitMaxPool);
                static void EMITTER_DECL(EmitReverse);
@@ -93,6 +96,8 @@ namespace ngraph
                static void EMITTER_DECL(EmitAvgPool);
                static void EMITTER_DECL(EmitPad);

                static void EmitMKLDNNPreamble(codegen::CodeWriter& writer);

            private:
                static std::string emit_vector(const TensorViewWrapper&,
                                               const std::string& name = "");
......
@@ -92,6 +92,7 @@
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"

using namespace std;
@@ -143,6 +144,7 @@ static StaticInitializers s_static_initializers;
static const runtime::cpu::OpMap dispatcher{
    {TI(ngraph::op::Add), &runtime::cpu::CPU_Emitter::EmitAdd},
    {TI(ngraph::op::MatmulBias), &runtime::cpu::CPU_Emitter::EmitMatmulBias},
    {TI(ngraph::op::Dot), &runtime::cpu::CPU_Emitter::EmitDot},
    {TI(ngraph::op::Multiply), &runtime::cpu::CPU_Emitter::EmitMultiply},
    {TI(ngraph::op::Parameter), &runtime::cpu::CPU_Emitter::EmitNop},
@@ -187,6 +189,10 @@ static const runtime::cpu::OpMap dispatcher{
    {TI(ngraph::op::Ceiling), &runtime::cpu::CPU_Emitter::EmitCeiling},
    {TI(ngraph::op::Sqrt), &runtime::cpu::CPU_Emitter::EmitSqrt},
    {TI(ngraph::op::Convolution), &runtime::cpu::CPU_Emitter::EmitConvolution},
    {TI(ngraph::op::ConvolutionBackpropFilters),
     &runtime::cpu::CPU_Emitter::EmitConvolutionBackpropFilters},
    {TI(ngraph::op::ConvolutionBackpropData),
     &runtime::cpu::CPU_Emitter::EmitConvolutionBackpropData},
    {TI(ngraph::op::Not), &runtime::cpu::CPU_Emitter::EmitNot},
    {TI(ngraph::op::MaxPool), &runtime::cpu::CPU_Emitter::EmitMaxPool},
    {TI(ngraph::op::Reverse), &runtime::cpu::CPU_Emitter::EmitReverse},
@@ -481,6 +487,8 @@ using namespace ngraph::runtime;
        writer << "tbb::flow::graph G;\n\n";
    }

    runtime::cpu::CPU_Emitter::EmitMKLDNNPreamble(writer);

    bool temporaries_used = false;
    size_t worst_case_tmp_size = 0;
    for (shared_ptr<Node> node : current_function->get_ordered_ops())
......
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "matmul_bias.hpp"
std::shared_ptr<ngraph::Node> ngraph::op::MatmulBias::copy_with_new_args(
const std::vector<std::shared_ptr<ngraph::Node>>& new_args) const
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<MatmulBias>(new_args.at(0),
new_args.at(1),
new_args.at(1),
m_shape_w,
m_shape_x,
m_transpose_w,
m_transpose_x);
}
ngraph::op::MatmulBias::MatmulBias(std::shared_ptr<ngraph::Node> W,
std::shared_ptr<ngraph::Node> x,
std::shared_ptr<ngraph::Node> b,
Shape shape_w,
Shape shape_x,
bool transpose_w,
bool transpose_x)
: RequiresTensorViewArgs("CblassGemm", {W, x, b})
, m_shape_w(shape_w)
, m_shape_x(shape_x)
, m_transpose_w(transpose_w)
, m_transpose_x(transpose_x)
{
if (shape_w.size() != 2)
{
NGRAPH_DEBUG << "W shape = " << vector_to_string(shape_w);
throw ngraph_error("W.shape.rank != 2 while creating MatmulBias");
}
if (shape_x.size() != 2)
{
NGRAPH_DEBUG << "x shape = " << vector_to_string(shape_x);
throw ngraph_error("x.shape.rank != 2 while creating MatmulBias");
}
size_t dot_dimension_w = (transpose_w) ? 0 : 1;
size_t dot_dimension_x = (transpose_x) ? 1 : 0;
NGRAPH_DEBUG << "dot_dimension_w = " << dot_dimension_w
<< " , dot_dimension_x = " << dot_dimension_x;
NGRAPH_DEBUG << "W shape = " << vector_to_string(shape_w)
<< " , x shape = " << vector_to_string(shape_x);
if (shape_w.at(dot_dimension_w) != shape_x.at(dot_dimension_x))
{
throw ngraph_error("product dimensions are not equal while creating MatmulBias");
}
auto dot_shape = Shape{shape_w.at(1 - dot_dimension_w), shape_x.at(1 - dot_dimension_x)};
NGRAPH_DEBUG << "dot_shape shape = " << vector_to_string(dot_shape)
<< " , b shape = " << vector_to_string(b->get_shape());
add_output(W->get_element_type(), dot_shape);
}
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/util.hpp"
#include <memory>
namespace ngraph
{
namespace op
{
class MatmulBias : public RequiresTensorViewArgs
{
public:
MatmulBias(std::shared_ptr<Node> W,
std::shared_ptr<Node> x,
std::shared_ptr<Node> b,
Shape shape_w,
Shape shape_x,
bool transpose_w,
bool transpose_x);
bool get_is_arg0_transposed() const { return m_transpose_w; }
bool get_is_arg1_transposed() const { return m_transpose_x; }
Shape get_arg0_shape() const { return m_shape_w; }
Shape get_arg1_shape() const { return m_shape_x; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override;
private:
Shape m_shape_w;
Shape m_shape_x;
bool m_transpose_w;
bool m_transpose_x;
};
}
}
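A hypothetical construction sketch, reusing the (2x4) x (4x1) shapes that the fusion pattern later in this diff is written against:

```cpp
auto W = std::make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto x = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
auto b = std::make_shared<op::Parameter>(element::f32, Shape{1});

auto mmb = std::make_shared<op::MatmulBias>(W, x, b,
                                            W->get_shape(),
                                            x->get_shape(),
                                            false /* transpose_w */,
                                            false /* transpose_x */);
// With no transposes, the dot dimensions are W[1] and x[0] (both 4),
// so mmb's output shape is {2, 1}.
```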
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "cpu_fusion.hpp"
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
std::shared_ptr<ngraph::Node> arg,
bool& transpose_w,
ngraph::Shape& shape_w)
{
auto r_w = std::dynamic_pointer_cast<ngraph::op::Reshape>(reshape);
if (!r_w)
{
return true; //nth to do; reshape isn't a reshape
}
if (r_w->get_shape().size() != 2)
{
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " doesn't reshape into matrix"
<< ngraph::vector_to_string(r_w->get_shape());
return false;
}
auto io = r_w->get_input_order();
if (r_w->get_shape().size() != arg->get_shape().size()) //reshape
{
ngraph::AxisVector dio(io.size());
std::iota(begin(dio), end(dio), 0);
if (io != dio) //we can't reshape and transpose at the same time
{
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " is not in default order "
<< ngraph::vector_to_string(io);
NGRAPH_DEBUG << "r_w shape = " << ngraph::vector_to_string(r_w->get_shape());
NGRAPH_DEBUG << "arg shape = " << ngraph::vector_to_string(arg->get_shape());
return false;
}
shape_w = r_w->get_shape();
}
else
{
if (io == ngraph::AxisVector{1, 0})
{
transpose_w = true;
}
//otherwise no-op reshape
}
return true;
}
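A sketch of the case `init_cblas_arg` is detecting: a `Reshape` that only swaps the two axes is recorded as a CBLAS-style transpose rather than materialized. The `op::Reshape` constructor signature here is assumed from the wider codebase:

```cpp
auto W  = std::make_shared<op::Parameter>(element::f32, Shape{4, 2});
auto Wt = std::make_shared<op::Reshape>(W, AxisVector{1, 0}, Shape{2, 4}); // pure transpose

bool transpose_w = false;
Shape shape_w = W->get_shape();
init_cblas_arg(Wt, W, transpose_w, shape_w);
// transpose_w is now true and shape_w stays {4, 2}; the emitter can ask the
// GEMM call to transpose W instead of inserting an explicit reshape.
```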
template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
std::vector<T> output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
void ngraph::pass::CPUFusion::construct_gemm_pattern()
{
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto shape_dot = Shape{2, 1};
auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);
auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};
auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);
auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b);
auto pbroadcast = std::make_shared<op::Broadcast>(b, shape_dot, AxisSet{0});
auto padd = pdot + pbroadcast;
ngraph::pattern::gr_callback_fn callback = [W, x, b](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_gemm_pattern against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn = nullptr;
auto mpattern = m.match_root();
if (mpattern->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
return nn;
}
auto dot = mpattern->get_input_op(0);
if (dot->get_shape().size() != 2)
{
NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!";
return nn;
}
bool transpose_w = false;
Shape shape_arg0{pattern_map[W]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0))
{
return nn;
}
bool transpose_x = false;
Shape shape_arg1{pattern_map[x]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1))
{
return nn;
}
auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W],
pattern_map[x],
mpattern->get_input_op(1),
shape_arg0,
shape_arg1,
transpose_w,
transpose_x));
return cg;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
this->add_matcher(m);
}
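For reference, a graph fragment that this pattern is intended to match could be built as follows (a sketch; the ops are the ordinary nGraph ops named in the pattern above):

```cpp
auto W = std::make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto x = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
auto b = std::make_shared<op::Parameter>(element::f32, Shape{1});

auto dot   = std::make_shared<op::Dot>(W, x);
auto bcast = std::make_shared<op::Broadcast>(b, Shape{2, 1}, AxisSet{0});
auto add   = dot + bcast; // this Add is what m.match_root() points at

// After the pass runs, the Dot/Broadcast/Add subtree is replaced by a single
// MatmulBias node (see the callback above).
```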
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace pass
{
class CPUFusion;
}
}
class ngraph::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
CPUFusion()
: GraphRewrite()
{
construct_gemm_pattern();
}
private:
void construct_gemm_pattern();
};
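A sketch of how the pass would typically be run; the pass-manager interface and the include paths are assumed from the wider nGraph codebase rather than shown in this diff:

```cpp
#include "ngraph/pass/manager.hpp"                // assumed location of pass::Manager
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp" // assumed installed path of this header

void run_cpu_fusion(std::shared_ptr<ngraph::Function> func)
{
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::CPUFusion>();
    pass_manager.run_passes(func); // Dot + broadcast-bias subgraphs become MatmulBias
}
```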
@@ -322,7 +322,59 @@ private:
c->get_window_dilation_strides(),
c->get_padding_below(),
c->get_padding_above(),
c->get_data_dilation_strides(),
0,
1,
1,
0,
0,
1,
false);
}
else if (node_op == "ConvolutionBackpropFilters")
{
auto c = static_cast<const op::ConvolutionBackpropFilters*>(&node);
kernel::convolution<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
1,      // batch_axis_data
0,      // input_channel_axis_data
0,      // input_channel_axis_filters
1,      // output_channel_axis_filters
1,      // batch_axis_result
0,      // output_channel_axis_result
false); // rotate_filter
}
else if (node_op == "ConvolutionBackpropData")
{
// Note that args[1] and args[0] are switched here from the usual order.
auto c = static_cast<const op::ConvolutionBackpropData*>(&node);
kernel::convolution<T>(reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[1]->get_shape(),
args[0]->get_shape(),
out[0]->get_shape(),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
0,      // batch_axis_data
1,      // input_channel_axis_data
0,      // input_channel_axis_filters
1,      // output_channel_axis_filters
0,      // batch_axis_result
1,      // output_channel_axis_result
true);  // rotate_filter
} }
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
......
...@@ -42,36 +42,36 @@ namespace ngraph ...@@ -42,36 +42,36 @@ namespace ngraph
{ {
// Our output coordinate O will have the form: // Our output coordinate O will have the form:
// //
// (img,chan,i_1,...,i_n) // (N,chan,i_1,...,i_n)
size_t img_index = out_coord[0]; size_t batch_index = out_coord[0];
size_t channel = out_coord[1]; size_t channel = out_coord[1];
// For the input images we need to iterate the coordinate: // For the input data we need to iterate the coordinate:
// //
// I: // I:
// //
// over the range (noninclusive on the right): // over the range (noninclusive on the right):
// //
// (img,chan,s_1*i_1,s_2*i_2,...,s_n*i_n) -> // (N,chan,s_1*i_1,s_2*i_2,...,s_n*i_n) ->
// //
// (img+1,chan+1,s_1*i_1 + window_shape_1,...,s_n*i_n + window_shape_n) // (N+1,chan+1,s_1*i_1 + window_shape_1,...,s_n*i_n + window_shape_n)
// //
// with unit stride. // with unit stride.
// //
// We iterate this over the *padded* image, so below we will need to check for coordinates that fall in the padding area. // We iterate this over the *padded* data, so below we will need to check for coordinates that fall in the padding area.
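// Illustrative example (not in the original source): for a 2-D window of shape
// (3, 3) with movement strides (2, 2), the output element (N, chan, 1, 2) reads
// the padded input range (N, chan, 2, 4) to (N+1, chan+1, 5, 7), noninclusive
// on the right.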
size_t n_image_dimensions = arg_shape.size() - 2; size_t n_spatial_dimensions = arg_shape.size() - 2;
Coordinate input_batch_transform_start(2 + n_image_dimensions); Coordinate input_batch_transform_start(2 + n_spatial_dimensions);
Coordinate input_batch_transform_end(2 + n_image_dimensions); Coordinate input_batch_transform_end(2 + n_spatial_dimensions);
Strides input_batch_transform_source_strides(2 + n_image_dimensions, 1); Strides input_batch_transform_source_strides(2 + n_spatial_dimensions, 1);
AxisVector input_batch_transform_source_axis_order(2 + n_image_dimensions); AxisVector input_batch_transform_source_axis_order(2 + n_spatial_dimensions);
CoordinateDiff input_batch_transform_padding_below(2 + n_image_dimensions); CoordinateDiff input_batch_transform_padding_below(2 + n_spatial_dimensions);
CoordinateDiff input_batch_transform_padding_above(2 + n_image_dimensions); CoordinateDiff input_batch_transform_padding_above(2 + n_spatial_dimensions);
input_batch_transform_start[0] = img_index; input_batch_transform_start[0] = batch_index;
input_batch_transform_end[0] = img_index + 1; input_batch_transform_end[0] = batch_index + 1;
input_batch_transform_start[1] = channel; input_batch_transform_start[1] = channel;
input_batch_transform_end[1] = channel + 1; input_batch_transform_end[1] = channel + 1;
input_batch_transform_padding_below[0] = 0; input_batch_transform_padding_below[0] = 0;
...@@ -79,7 +79,7 @@ namespace ngraph ...@@ -79,7 +79,7 @@ namespace ngraph
input_batch_transform_padding_above[0] = 0; input_batch_transform_padding_above[0] = 0;
input_batch_transform_padding_above[1] = 0; input_batch_transform_padding_above[1] = 0;
for (size_t i = 2; i < n_image_dimensions + 2; i++) for (size_t i = 2; i < n_spatial_dimensions + 2; i++)
{ {
size_t window_shape_this_dim = window_shape[i - 2]; size_t window_shape_this_dim = window_shape[i - 2];
size_t movement_stride = window_movement_strides[i - 2]; size_t movement_stride = window_movement_strides[i - 2];
......
...@@ -37,8 +37,23 @@ namespace ngraph ...@@ -37,8 +37,23 @@ namespace ngraph
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& image_dilation_strides) const Strides& data_dilation_strides,
size_t batch_axis_data,
size_t input_channel_axis_data,
size_t input_channel_axis_filters,
size_t output_channel_axis_filters,
size_t batch_axis_result,
size_t output_channel_axis_result,
bool rotate_filter)
{ {
// Comments throughout assume without loss of generality that:
//
// * batch axes for both input data and output data are 0
// * input channel axes for both input data and filters are 1
// * output channel axes for filters is 0
// * output channel axis for output data is 1
// * rotate_filter is false
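//
// Under those assumptions, for a 2-D convolution the data batch has shape
// (N, C_in, H, W), the filters have shape (C_out, C_in, Kh, Kw), and the output
// has shape (N, C_out, H_out, W_out). (Illustrative example only; the axis
// parameters above let the backprop kernels use other layouts.)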
// At the outermost level we will walk over every output coordinate O. // At the outermost level we will walk over every output coordinate O.
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
...@@ -46,50 +61,50 @@ namespace ngraph ...@@ -46,50 +61,50 @@ namespace ngraph
{ {
// Our output coordinate O will have the form: // Our output coordinate O will have the form:
// //
// (img,chan_out,i_1,...,i_n) // (N,chan_out,i_1,...,i_n)
size_t img_index = out_coord[0]; size_t batch_index = out_coord[batch_axis_result];
size_t output_channel = out_coord[1]; size_t output_channel = out_coord[output_channel_axis_result];
// For the input images we need to iterate the coordinate: // For the input data we need to iterate the coordinate:
// //
// I: // I:
// //
// over the range (noninclusive on the right): // over the range (noninclusive on the right):
// //
// (img,0,s_1*i_1,s_2*i_2,...,s_n*i_n) -> // (N,0,s_1*i_1,s_2*i_2,...,s_n*i_n) ->
// //
// (img+1,chans_in_count,s_1*i_1 + l_1*filter_dims_1,...,s_n*i_n + l_n*filter_dims_n) // (N+1,chans_in_count,s_1*i_1 + l_1*filter_dims_1,...,s_n*i_n + l_n*filter_dims_n)
// //
// with strides: // with strides:
// //
// (1,l_1,...,l_n). // (1,l_1,...,l_n).
// //
// Note that we are iterating within the *padded* and *dilated* image batch, so further // Note that we are iterating within the *padded* and *dilated* data batch, so further
// down we must check the current coordinate is in the padding or dilation gap. // down we must check the current coordinate is in the padding or dilation gap.
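// Illustrative example (not in the original source): with movement stride
// s_1 = 2, window dilation l_1 = 1, filter_dims_1 = 3 and output index i_1 = 4,
// the iteration range in that dimension is [8, 11) over the padded, dilated data.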
size_t n_image_dimensions = arg0_shape.size() - 2; size_t n_spatial_dimensions = arg0_shape.size() - 2;
size_t n_input_channels = arg0_shape[1]; size_t n_input_channels = arg0_shape[input_channel_axis_data];
Coordinate input_batch_transform_start(2 + n_image_dimensions); Coordinate input_batch_transform_start(2 + n_spatial_dimensions);
Coordinate input_batch_transform_end(2 + n_image_dimensions); Coordinate input_batch_transform_end(2 + n_spatial_dimensions);
Strides input_batch_transform_movement_strides(2 + n_image_dimensions, 1); Strides input_batch_transform_movement_strides(2 + n_spatial_dimensions, 1);
CoordinateDiff input_batch_transform_padding_below(2 + n_image_dimensions, 0); CoordinateDiff input_batch_transform_padding_below(2 + n_spatial_dimensions, 0);
CoordinateDiff input_batch_transform_padding_above(2 + n_image_dimensions, 0); CoordinateDiff input_batch_transform_padding_above(2 + n_spatial_dimensions, 0);
Strides input_batch_transform_dilation_strides(2 + n_image_dimensions, 1); Strides input_batch_transform_dilation_strides(2 + n_spatial_dimensions, 1);
input_batch_transform_start[0] = img_index; input_batch_transform_start[batch_axis_data] = batch_index;
input_batch_transform_end[0] = img_index + 1; input_batch_transform_end[batch_axis_data] = batch_index + 1;
input_batch_transform_start[1] = 0; input_batch_transform_start[input_channel_axis_data] = 0;
input_batch_transform_end[1] = n_input_channels; input_batch_transform_end[input_channel_axis_data] = n_input_channels;
for (size_t i = 2; i < n_image_dimensions + 2; i++) for (size_t i = 2; i < n_spatial_dimensions + 2; i++)
{ {
size_t window_dilation_stride = window_dilation_strides[i - 2]; size_t window_dilation_stride = window_dilation_strides[i - 2];
size_t window_movement_stride = window_movement_strides[i - 2]; size_t window_movement_stride = window_movement_strides[i - 2];
std::ptrdiff_t below_pad = padding_below[i - 2]; std::ptrdiff_t below_pad = padding_below[i - 2];
std::ptrdiff_t above_pad = padding_above[i - 2]; std::ptrdiff_t above_pad = padding_above[i - 2];
size_t image_dilation_stride = image_dilation_strides[i - 2]; size_t data_dilation_stride = data_dilation_strides[i - 2];
input_batch_transform_start[i] = window_movement_stride * out_coord[i]; input_batch_transform_start[i] = window_movement_stride * out_coord[i];
input_batch_transform_end[i] = input_batch_transform_end[i] =
...@@ -98,10 +113,10 @@ namespace ngraph ...@@ -98,10 +113,10 @@ namespace ngraph
input_batch_transform_movement_strides[i] = window_dilation_stride; input_batch_transform_movement_strides[i] = window_dilation_stride;
input_batch_transform_padding_below[i] = below_pad; input_batch_transform_padding_below[i] = below_pad;
input_batch_transform_padding_above[i] = above_pad; input_batch_transform_padding_above[i] = above_pad;
input_batch_transform_dilation_strides[i] = image_dilation_stride; input_batch_transform_dilation_strides[i] = data_dilation_stride;
} }
AxisVector input_batch_transform_axis_order(2 + n_image_dimensions); AxisVector input_batch_transform_axis_order(2 + n_spatial_dimensions);
size_t n = 0; size_t n = 0;
std::generate(input_batch_transform_axis_order.begin(), std::generate(input_batch_transform_axis_order.begin(),
input_batch_transform_axis_order.end(), input_batch_transform_axis_order.end(),
...@@ -127,15 +142,15 @@ namespace ngraph ...@@ -127,15 +142,15 @@ namespace ngraph
// //
// with unit stride. // with unit stride.
Shape filter_transform_start(2 + n_image_dimensions); Shape filter_transform_start(2 + n_spatial_dimensions);
Shape filter_transform_end(2 + n_image_dimensions); Shape filter_transform_end(2 + n_spatial_dimensions);
filter_transform_start[0] = output_channel; filter_transform_start[output_channel_axis_filters] = output_channel;
filter_transform_end[0] = output_channel + 1; filter_transform_end[output_channel_axis_filters] = output_channel + 1;
filter_transform_start[1] = 0; filter_transform_start[input_channel_axis_filters] = 0;
filter_transform_end[1] = n_input_channels; filter_transform_end[input_channel_axis_filters] = n_input_channels;
for (size_t i = 2; i < n_image_dimensions + 2; i++) for (size_t i = 2; i < n_spatial_dimensions + 2; i++)
{ {
filter_transform_start[i] = 0; filter_transform_start[i] = 0;
filter_transform_end[i] = arg1_shape[i]; filter_transform_end[i] = arg1_shape[i];
...@@ -157,7 +172,19 @@ namespace ngraph ...@@ -157,7 +172,19 @@ namespace ngraph
filter_it != filter_transform.end()) filter_it != filter_transform.end())
{ {
const Coordinate& input_batch_coord = *input_it; const Coordinate& input_batch_coord = *input_it;
const Coordinate& filter_coord = *filter_it; Coordinate filter_coord = *filter_it;
if (rotate_filter)
{
Shape target_shape = filter_transform.get_target_shape();
// Note that we only reverse the spatial dimensions here (loop
// starts at 2)
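// Illustrative example: for a 3x3 filter (target_shape[i] == 3), spatial
// index 0 maps to 2, 1 stays at 1, and 2 maps to 0.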
for (size_t i = 2; i < filter_coord.size(); i++)
{
filter_coord[i] = target_shape[i] - filter_coord[i] - 1;
}
}
T v = input_batch_transform.has_source_coordinate(input_batch_coord) T v = input_batch_transform.has_source_coordinate(input_batch_coord)
? arg0[input_batch_transform.index(input_batch_coord)] ? arg0[input_batch_transform.index(input_batch_coord)]
......
...@@ -40,34 +40,34 @@ namespace ngraph ...@@ -40,34 +40,34 @@ namespace ngraph
{ {
// Our output coordinate O will have the form: // Our output coordinate O will have the form:
// //
// (img,chan,i_1,...,i_n) // (N,chan,i_1,...,i_n)
size_t img_index = out_coord[0]; size_t batch_index = out_coord[0];
size_t channel = out_coord[1]; size_t channel = out_coord[1];
// For the input images we need to iterate the coordinate: // For the input data we need to iterate the coordinate:
// //
// I: // I:
// //
// over the range (noninclusive on the right): // over the range (noninclusive on the right):
// //
// (img,chan,s_1*i_1,s_2*i_2,...,s_n*i_n) -> // (N,chan,s_1*i_1,s_2*i_2,...,s_n*i_n) ->
// //
// (img+1,chan+1,s_1*i_1 + window_shape_1,...,s_n*i_n + window_shape_n) // (N+1,chan+1,s_1*i_1 + window_shape_1,...,s_n*i_n + window_shape_n)
// //
// with unit stride. // with unit stride.
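// Illustrative example (not in the original source): with a (2, 2) window and
// unit movement strides, the output element (N, chan, i, j) reads the input
// range (N, chan, i, j) to (N+1, chan+1, i+2, j+2), noninclusive on the right.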
size_t n_image_dimensions = arg_shape.size() - 2; size_t n_spatial_dimensions = arg_shape.size() - 2;
Coordinate input_batch_transform_start(2 + n_image_dimensions); Coordinate input_batch_transform_start(2 + n_spatial_dimensions);
Coordinate input_batch_transform_end(2 + n_image_dimensions); Coordinate input_batch_transform_end(2 + n_spatial_dimensions);
input_batch_transform_start[0] = img_index; input_batch_transform_start[0] = batch_index;
input_batch_transform_end[0] = img_index + 1; input_batch_transform_end[0] = batch_index + 1;
input_batch_transform_start[1] = channel; input_batch_transform_start[1] = channel;
input_batch_transform_end[1] = channel + 1; input_batch_transform_end[1] = channel + 1;
for (size_t i = 2; i < n_image_dimensions + 2; i++) for (size_t i = 2; i < n_spatial_dimensions + 2; i++)
{ {
size_t window_shape_this_dim = window_shape[i - 2]; size_t window_shape_this_dim = window_shape[i - 2];
size_t movement_stride = window_movement_strides[i - 2]; size_t movement_stride = window_movement_strides[i - 2];
......
...@@ -379,15 +379,79 @@ static shared_ptr<ngraph::Function> ...@@ -379,15 +379,79 @@ static shared_ptr<ngraph::Function>
node_js.at("window_dilation_strides").get<vector<size_t>>(); node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>(); auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>(); auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto image_dilation_strides =
node_js.at("image_dilation_strides").get<vector<size_t>>(); // For backwards compatibility, we accept "image_dilation_strides" in place of
node = make_shared<op::Convolution>(args[0], // "data_dilation_strides", and we also allow it to be omitted altogether.
args[1], auto data_dilation_strides_maybe = node_js["data_dilation_strides"];
window_movement_strides, if (data_dilation_strides_maybe.empty())
window_dilation_strides, {
padding_below, data_dilation_strides_maybe = node_js["image_dilation_strides"];
padding_above, }
image_dilation_strides);
if (data_dilation_strides_maybe.empty())
{
node = make_shared<op::Convolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above);
}
else
{
node = make_shared<op::Convolution>(
args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides_maybe.get<std::vector<size_t>>());
}
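// For illustration only (hypothetical field values, not from the original source):
// an older serialized Convolution node may carry
//     "image_dilation_strides": [1, 1]
// while a newer one carries
//     "data_dilation_strides": [1, 1];
// either spelling, or omitting the field altogether, deserializes correctly above.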
}
else if (node_op == "ConvolutionBackpropData")
{
auto data_batch_shape = node_js.at("data_batch_shape").get<vector<size_t>>();
auto window_movement_strides_forward =
node_js.at("window_movement_strides_forward").get<vector<size_t>>();
auto window_dilation_strides_forward =
node_js.at("window_dilation_strides_forward").get<vector<size_t>>();
auto padding_below_forward =
node_js.at("padding_below_forward").get<vector<std::ptrdiff_t>>();
auto padding_above_forward =
node_js.at("padding_above_forward").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides_forward =
node_js.at("data_dilation_strides_forward").get<vector<size_t>>();
node = make_shared<op::ConvolutionBackpropData>(data_batch_shape,
args[0],
args[1],
window_movement_strides_forward,
window_dilation_strides_forward,
padding_below_forward,
padding_above_forward,
data_dilation_strides_forward);
}
else if (node_op == "ConvolutionBackpropFilters")
{
auto filters_shape = node_js.at("filters_shape").get<vector<size_t>>();
auto window_movement_strides_forward =
node_js.at("window_movement_strides_forward").get<vector<size_t>>();
auto window_dilation_strides_forward =
node_js.at("window_dilation_strides_forward").get<vector<size_t>>();
auto padding_below_forward =
node_js.at("padding_below_forward").get<vector<std::ptrdiff_t>>();
auto padding_above_forward =
node_js.at("padding_above_forward").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides_forward =
node_js.at("data_dilation_strides_forward").get<vector<size_t>>();
node = make_shared<op::ConvolutionBackpropFilters>(args[0],
filters_shape,
args[1],
window_movement_strides_forward,
window_dilation_strides_forward,
padding_below_forward,
padding_above_forward,
data_dilation_strides_forward);
} }
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
...@@ -718,7 +782,27 @@ static json write(const Node& n) ...@@ -718,7 +782,27 @@ static json write(const Node& n)
node["window_dilation_strides"] = tmp->get_window_dilation_strides(); node["window_dilation_strides"] = tmp->get_window_dilation_strides();
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["image_dilation_strides"] = tmp->get_image_dilation_strides(); node["data_dilation_strides"] = tmp->get_data_dilation_strides();
}
else if (node_op == "ConvolutionBackpropData")
{
auto tmp = dynamic_cast<const op::ConvolutionBackpropData*>(&n);
node["data_batch_shape"] = tmp->get_data_batch_shape();
node["window_movement_strides_forward"] = tmp->get_window_movement_strides_forward();
node["window_dilation_strides_forward"] = tmp->get_window_dilation_strides_forward();
node["padding_below_forward"] = tmp->get_padding_below_forward();
node["padding_above_forward"] = tmp->get_padding_above_forward();
node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward();
}
else if (node_op == "ConvolutionBackpropFilters")
{
auto tmp = dynamic_cast<const op::ConvolutionBackpropFilters*>(&n);
node["filters_shape"] = tmp->get_filters_shape();
node["window_movement_strides_forward"] = tmp->get_window_movement_strides_forward();
node["window_dilation_strides_forward"] = tmp->get_window_dilation_strides_forward();
node["padding_below_forward"] = tmp->get_padding_below_forward();
node["padding_above_forward"] = tmp->get_padding_above_forward();
node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward();
} }
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
......
...@@ -69,7 +69,7 @@ endif() ...@@ -69,7 +69,7 @@ endif()
if(NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR) if(NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR)
include_directories(SYSTEM ${LLVM_INCLUDE_DIR}) include_directories(SYSTEM ${LLVM_INCLUDE_DIR})
link_directories(${LLVM_LIB_DIR}) link_directories(${LLVM_LIB_DIR})
set(SRC ${SRC} backend_performance.cpp codegen.cpp) set(SRC ${SRC} backend_performance.cpp codegen.cpp cpu_fusion.cpp)
set(BACKEND_NAMES ${BACKEND_NAMES} "CPU") set(BACKEND_NAMES ${BACKEND_NAMES} "CPU")
endif() endif()
......
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
//
#include "ngraph/file_util.hpp"
#include "ngraph/json.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(cpu_fusion, gemm_pattern)
{
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);
auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;
auto W = std::make_shared<pattern::op::Label>(A);
auto x = std::make_shared<pattern::op::Label>(B);
auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};
auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);
auto pdot = make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(C);
auto pbroadcast = make_shared<op::Broadcast>(b, dot->get_shape(), AxisSet{0});
auto padd = pdot + pbroadcast;
TestMatcher n(nullptr);
ASSERT_TRUE(n.match(padd, add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, W->get_shape());
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, x->get_shape());
auto re_dot = make_shared<op::Dot>(reshape_w, reshape_x);
auto re_add = re_dot + broadcast;
ASSERT_TRUE(n.match(padd, re_add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);
auto cg =
make_shared<op::MatmulBias>(W, x, broadcast, W->get_shape(), x->get_shape(), false, false);
}
TEST(cpu_fusion, gemm_cpu)
{
auto shapeA = Shape{3, 2};
auto shapeB = Shape{2, 3};
auto shapeC = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeB);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, Shape{2, 3});
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, Shape{3, 2});
auto one = op::Constant::create<float>(element::f32, Shape{}, std::vector<float>{1.0f});
auto broadcast = make_shared<op::Broadcast>(one, shapeC, AxisSet{0, 1});
auto cg =
make_shared<op::MatmulBias>(A, B, broadcast, A->get_shape(), B->get_shape(), true, true);
auto f = make_shared<Function>(cg, op::Parameters{A, B});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shapeA);
shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shapeB);
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, shapeC);
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{3.0f, 3.0f, 3.0f, 9.0f, 9.0f, 9.0f};
copy_data(a, dataA);
copy_data(b, dataB);
cf->call({a, b}, {result});
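// With transpose_w and transpose_x both true, the fused op computes
// A^T (2x3) * B^T (3x2) plus the broadcast bias of 1:
//   [[1,1,1],[4,4,4]] * [[3,9],[3,9],[3,9]] + 1 = [[10,28],[37,109]]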
vector<float> expected{10, 28, 37, 109};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
TEST(cpu_fusion, cpu_fusion_pass_basic)
{
auto shape = Shape{};
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);
auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::Parameters{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
}
TEST(cpu_fusion, gemm_mlp)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/mnist_mlp_forward.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>();
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(ccg, 3);
}
...@@ -28,47 +28,11 @@ ...@@ -28,47 +28,11 @@
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp" #include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "util/matcher.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
//this is for more nuanced testing
class TestMatcher : public pattern::Matcher
{
using pattern::Matcher::Matcher;
bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map) override
{
if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node))
{
return pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get());
}
return this->pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
}
public:
bool match(const std::shared_ptr<Node>& pattern_node, const std::shared_ptr<Node>& graph_node)
{
assert(
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
};
template <typename T> template <typename T>
std::shared_ptr<Node> create_reduction(const std::shared_ptr<Node>& node, std::shared_ptr<Node> create_reduction(const std::shared_ptr<Node>& node,
const std::string& init_val, const std::string& init_val,
...@@ -181,13 +145,13 @@ public: ...@@ -181,13 +145,13 @@ public:
auto second_node = m.match_root()->get_input_ops().at(const_node_index); auto second_node = m.match_root()->get_input_ops().at(const_node_index);
NGRAPH_DEBUG << "second_node = " << second_node->get_name() NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name(); << " , pattern = " << pattern_map[pattern]->get_name();
ASSERT_TRUE(const_node);
std::shared_ptr<ngraph::Node> nn = nullptr;
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() || if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape()) pattern_map[pattern]->get_shape() != const_node->get_shape())
{ {
NGRAPH_DEBUG << "Operands' types and/or shape don't match"; NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return; return nn;
} }
auto const_values = const_node->get_vector<int32_t>(); auto const_values = const_node->get_vector<int32_t>();
...@@ -197,9 +161,9 @@ public: ...@@ -197,9 +161,9 @@ public:
if (!all_ones) if (!all_ones)
{ {
NGRAPH_DEBUG << "Constant vector's values aren't equal to 1"; NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
return; return nn;
} }
ngraph::replace_node(m.match_root(), pattern_map[pattern]); return pattern_map[pattern];
}; };
auto m = make_shared<TestMatcher>(pattern * iconst1, callback); auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
...@@ -212,7 +176,7 @@ public: ...@@ -212,7 +176,7 @@ public:
auto iconst0 = construct_constant_node(0); auto iconst0 = construct_constant_node(0);
auto pattern = std::make_shared<pattern::op::Label>(iconst0); auto pattern = std::make_shared<pattern::op::Label>(iconst0);
ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) { auto callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_add_zero against " NGRAPH_DEBUG << "In a callback for construct_add_zero against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
assert(m.match_root()->get_input_ops().size() == 2); assert(m.match_root()->get_input_ops().size() == 2);
...@@ -225,13 +189,15 @@ public: ...@@ -225,13 +189,15 @@ public:
auto second_node = m.match_root()->get_input_ops().at(const_node_index); auto second_node = m.match_root()->get_input_ops().at(const_node_index);
NGRAPH_DEBUG << "second_node = " << second_node->get_name() NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name(); << " , pattern = " << pattern_map[pattern]->get_name();
ASSERT_NE(nullptr, const_node);
//ASSERT_NE(nullptr, const_node);
std::shared_ptr<ngraph::Node> nn = nullptr;
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() || if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape()) pattern_map[pattern]->get_shape() != const_node->get_shape())
{ {
NGRAPH_DEBUG << "Operands' types and/or shape don't match"; NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return; return nn;
} }
auto const_values = const_node->get_vector<int>(); auto const_values = const_node->get_vector<int>();
...@@ -241,10 +207,10 @@ public: ...@@ -241,10 +207,10 @@ public:
if (!all_zeros) if (!all_zeros)
{ {
NGRAPH_DEBUG << "Constant vector's values aren't equal to 0"; NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
return; return nn;
} }
ngraph::replace_node(m.match_root(), pattern_map[pattern]); return pattern_map[pattern];
}; };
auto m = make_shared<TestMatcher>(pattern + iconst0, callback); auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
...@@ -261,8 +227,9 @@ public: ...@@ -261,8 +227,9 @@ public:
auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root()); auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root());
auto reducee = reduce->get_inputs().at(0).get_output().get_node(); auto reducee = reduce->get_inputs().at(0).get_output().get_node();
NGRAPH_DEBUG << "reducee = " << reducee->get_name(); NGRAPH_DEBUG << "reducee = " << reducee->get_name();
auto sum = std::make_shared<op::Sum>(reducee, reduce->get_reduction_axes()); auto sum =
ngraph::replace_node(reduce, sum); std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes()));
return sum;
}; };
auto m = make_shared<TestMatcher>(sum_pattern, callback); auto m = make_shared<TestMatcher>(sum_pattern, callback);
...@@ -290,9 +257,27 @@ TEST(pattern, graph_rewrite) ...@@ -290,9 +257,27 @@ TEST(pattern, graph_rewrite)
{ {
auto shape = Shape{}; auto shape = Shape{};
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<TestGraphRewrite>(); pass_manager.register_pass<TestGraphRewrite>();
{
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto iconst0 = construct_constant_node(0);
auto graph_a = a + iconst0;
auto graph_b = b + iconst0;
auto f = std::make_shared<Function>(ngraph::Nodes{a, b, graph_a, c, graph_b},
op::Parameters{a, b, c});
pass_manager.run_passes(f);
ASSERT_TRUE(graph_a->get_output_inputs(0).empty());
ASSERT_TRUE(graph_b->get_output_inputs(0).empty());
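// a + iconst0 and b + iconst0 are folded back to a and b by the rewrite,
// so the function's output list becomes {a, b, a, c, b}.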
auto expected = ngraph::Nodes{a, b, a, c, b};
ASSERT_TRUE(f->get_results() == expected);
}
{ {
auto a = make_shared<op::Parameter>(element::i32, shape); auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape); auto b = make_shared<op::Parameter>(element::i32, shape);
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
//this is for more nuanced testing
class TestMatcher : public ngraph::pattern::Matcher
{
using ngraph::pattern::Matcher::Matcher;
bool virtual match_node(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map) override
{
if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node))
{
return pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get());
}
return this->ngraph::pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
}
public:
bool match(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node)
{
assert(
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
};
template <typename T>
size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
{
size_t count = 0;
for (auto op : f->get_ops())
{
if (std::dynamic_pointer_cast<T>(op))
{
count++;
}
}
return count;
}
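// Illustrative usage (not part of the original header): after running a fusion
// pass over a Function f, e.g.
//
//     size_t n_gemms = count_ops_of_type<ngraph::op::MatmulBias>(f);
//
// counts how many MatmulBias nodes the pass produced.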