Commit 195b1aa8 authored by mbencer's avatar mbencer

Merge branch 'mbencer/BuilderSplitV1' of…

Merge branch 'mbencer/BuilderSplitV1' of https://github.com/NervanaSystems/ngraph into mbencer/BuilderSplitV1
parents 68698d51 ad3a67dc
...@@ -143,7 +143,7 @@ option(NGRAPH_PLAIDML_STATIC_LIB_ENABLE "Enable build PlaidML backend as a stati ...@@ -143,7 +143,7 @@ option(NGRAPH_PLAIDML_STATIC_LIB_ENABLE "Enable build PlaidML backend as a stati
option(NGRAPH_DYNAMIC_COMPONENTS_ENABLE "Enable dynamic loading of components" TRUE) option(NGRAPH_DYNAMIC_COMPONENTS_ENABLE "Enable dynamic loading of components" TRUE)
option(NGRAPH_NATIVE_ARCH_ENABLE "Enable build for native archtecture" TRUE) option(NGRAPH_NATIVE_ARCH_ENABLE "Enable build for native archtecture" TRUE)
option(NGRAPH_EXPORT_TARGETS_ENABLE "Enable exporting nGraph cmake export targets" TRUE) option(NGRAPH_EXPORT_TARGETS_ENABLE "Enable exporting nGraph cmake export targets" TRUE)
option(NGRAPH_WARNINGS_AS_ERRORS "Make all nGraph compile-time warnings into errors" TRUE) option(NGRAPH_WARNINGS_AS_ERRORS "Make all nGraph compile-time warnings into errors" FALSE)
if (NGRAPH_CPU_ENABLE) if (NGRAPH_CPU_ENABLE)
option(NGRAPH_TBB_ENABLE "Control usage of TBB for CPU backend" TRUE) option(NGRAPH_TBB_ENABLE "Control usage of TBB for CPU backend" TRUE)
endif() endif()
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
/doc/ @indie /doc/ @indie
/doc/examples/mnist_mlp/dist_* @wenzhe-nrv @indie /doc/examples/mnist_mlp/dist_* @wenzhe-nrv @indie
/doc/*/*/frameworks/tensorflow_connect.rst @shresthamalik @sayantan-nervana /doc/*/*/frameworks/tensorflow_connect.rst @shresthamalik @sayantan-nervana
/doc/*/*/backends/plaidml-ng-api/ @flaub @brianretford @dgkutnic /doc/*/*/backends/plaidml-ng-api/ @diyessi
/doc/*/*/inspection/ @diyessi /doc/*/*/inspection/ @diyessi
/doc/examples/onnx/ @arogowie-intel @tsocha /doc/examples/onnx/ @arogowie-intel @tsocha
/README.md @indie /README.md @indie
...@@ -55,7 +55,7 @@ project/doc-contributor-README.rst @indie ...@@ -55,7 +55,7 @@ project/doc-contributor-README.rst @indie
/src/ngraph/runtime/dynamic/ @diyessi /src/ngraph/runtime/dynamic/ @diyessi
/src/ngraph/runtime/gpu/ @csullivan @rkimballn1 /src/ngraph/runtime/gpu/ @csullivan @rkimballn1
/src/ngraph/runtime/interpreter/ @rkimballn1 /src/ngraph/runtime/interpreter/ @rkimballn1
/src/ngraph/runtime/plaidml/ @earhart @dgkutnic /src/ngraph/runtime/plaidml/ @diyessi
/src/ngraph/runtime/reference/ @diyessi /src/ngraph/runtime/reference/ @diyessi
/src/ngraph/runtime/reference/allreduce.*pp @wenzhe-nrv @diyessi /src/ngraph/runtime/reference/allreduce.*pp @wenzhe-nrv @diyessi
/src/ngraph/type/ @diyessi /src/ngraph/type/ @diyessi
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "mlir_subgraph_extraction.hpp" #include "mlir_subgraph_extraction.hpp"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops.hpp" #include "ngraph/ops.hpp"
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <sstream> #include <sstream>
#include "ngraph/env_util.hpp" #include "ngraph/env_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "matmul_integer.hpp" #include "matmul_integer.hpp"
#include "ngraph/builder/matmul_factory.hpp" #include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/log.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "qlinear_matmul.hpp" #include "qlinear_matmul.hpp"
#include "ngraph/builder/matmul_factory.hpp" #include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/log.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <mutex> #include <mutex>
#include <thread> #include <thread>
#include "ngraph/distributed.hpp"
#include "ngraph/env_util.hpp" #include "ngraph/env_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
......
...@@ -30,8 +30,6 @@ ...@@ -30,8 +30,6 @@
#endif #endif
#include <vector> #include <vector>
#include "ngraph/distributed.hpp"
namespace ngraph namespace ngraph
{ {
class ConstString class ConstString
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include "ngraph/distributed.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <sstream> #include <sstream>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/binary_convolution.hpp" #include "ngraph/op/binary_convolution.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/axis_vector.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/broadcast_distributed.hpp" #include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
......
...@@ -32,20 +32,20 @@ op::Squeeze::Squeeze(const Output<Node>& data, const Output<Node>& axes) ...@@ -32,20 +32,20 @@ op::Squeeze::Squeeze(const Output<Node>& data, const Output<Node>& axes)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
NodeVector op::Squeeze::decompose_op() const void op::Squeeze::pre_validate_and_infer_types()
{ {
auto data = input_value(0); auto data = input_value(0);
auto axes_node = input_value(1).get_node_shared_ptr(); auto axes_node = input_value(1).get_node_shared_ptr();
// Currently only support Constant node for axes. if (data.get_partial_shape().is_dynamic() || !axes_node->is_constant())
NODE_VALIDATION_CHECK(this, {
axes_node->is_constant(), set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
"doesn't support 'axes' input of other type than a Constant."); return;
}
// Get value of axes from Constant // Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node); auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->cast_vector<size_t>(); auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape(); auto data_shape = data.get_shape();
std::vector<uint64_t> axes_to_squeeze(data_shape.size()); std::vector<uint64_t> axes_to_squeeze(data_shape.size());
...@@ -87,6 +87,18 @@ NodeVector op::Squeeze::decompose_op() const ...@@ -87,6 +87,18 @@ NodeVector op::Squeeze::decompose_op() const
} }
} }
set_output_type(0, get_input_element_type(0), output_data_shape);
}
NodeVector op::Squeeze::decompose_op() const
{
NODE_VALIDATION_CHECK(
this,
(get_output_partial_shape(0).is_static()),
"output shape was not calculated during pre_validate_and_infer_types. Can not decompose.");
auto data = input_value(0);
auto data_shape = data.get_shape();
auto output_data_shape = get_output_shape(0);
AxisVector input_order{get_default_order(data_shape.size())}; AxisVector input_order{get_default_order(data_shape.size())};
return {make_shared<op::Reshape>(data, input_order, output_data_shape)}; return {make_shared<op::Reshape>(data, input_order, output_data_shape)};
} }
......
...@@ -38,6 +38,7 @@ namespace ngraph ...@@ -38,6 +38,7 @@ namespace ngraph
Squeeze(const Output<Node>& data, const Output<Node>& axes); Squeeze(const Output<Node>& data, const Output<Node>& axes);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -33,31 +33,35 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes) ...@@ -33,31 +33,35 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
} }
void op::Unsqueeze::pre_validate_and_infer_types() void op::Unsqueeze::pre_validate_and_infer_types()
{
auto axes_node = input_value(1).get_node_shared_ptr();
// Currently only support Constant node for axes.
NODE_VALIDATION_CHECK(this,
axes_node->is_constant(),
"doesn't support 'axes' input of other type than a Constant.");
}
NodeVector op::Unsqueeze::decompose_op() const
{ {
auto data = input_value(0); auto data = input_value(0);
auto axes_node = input_value(1).get_node_shared_ptr(); auto axes_node = input_value(1).get_node_shared_ptr();
if (data.get_partial_shape().rank().is_dynamic() || !axes_node->is_constant())
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
return;
}
// Get value of axes from Constant // Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node); auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->cast_vector<size_t>(); auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape();
NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory."); NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory.");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axes.size() == set<int64_t>(begin(axes), end(axes)).size(), axes.size() == set<int64_t>(begin(axes), end(axes)).size(),
"'axes' input has a duplicate axis."); "'axes' input has a duplicate axis.");
if (data.get_partial_shape().is_dynamic())
{
set_output_type(0,
get_input_element_type(0),
PartialShape::dynamic(data.get_partial_shape().rank() + axes.size()));
return;
}
auto data_shape = data.get_shape();
sort(begin(axes), end(axes), less<int64_t>()); sort(begin(axes), end(axes), less<int64_t>());
AxisVector input_order{ngraph::get_default_order(data_shape.size())}; AxisVector input_order{ngraph::get_default_order(data_shape.size())};
...@@ -69,8 +73,20 @@ NodeVector op::Unsqueeze::decompose_op() const ...@@ -69,8 +73,20 @@ NodeVector op::Unsqueeze::decompose_op() const
data_shape.insert(next(begin(data_shape), axis), 1); data_shape.insert(next(begin(data_shape), axis), 1);
} }
set_output_type(0, get_input_element_type(0), data_shape);
}
return {make_shared<ngraph::op::Reshape>(data, input_order, data_shape)}; NodeVector op::Unsqueeze::decompose_op() const
{
NODE_VALIDATION_CHECK(
this,
(get_output_partial_shape(0).is_static()),
"output shape was not calculated during pre_validate_and_infer_types. Can not decompose.");
auto data = input_value(0);
auto data_shape = data.get_shape();
auto output_shape = get_output_shape(0);
AxisVector input_order{ngraph::get_default_order(data_shape.size())};
return {make_shared<ngraph::op::Reshape>(data, input_order, output_shape)};
} }
shared_ptr<Node> op::Unsqueeze::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Unsqueeze::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/op/convert.hpp" #include "ngraph/op/convert.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp" #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/util/binary_elementwise_comparison.hpp" #include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/util/binary_elementwise_logical.hpp" #include "ngraph/op/util/binary_elementwise_logical.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/util/index_reduction.hpp" #include "ngraph/op/util/index_reduction.hpp"
using namespace std; using namespace std;
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "batch_fusion.hpp" #include "batch_fusion.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#pragma once #pragma once
#include "ngraph/log.hpp"
#include "ngraph/pass/graph_rewrite.hpp" #include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/runtime/aligned_buffer.hpp" #include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/pass/pass_util.hpp" #include "ngraph/pass/pass_util.hpp"
#include "ngraph/log.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "ngraph/pass/propagate_cacheability.hpp" #include "ngraph/pass/propagate_cacheability.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp" #include "ngraph/op/parameter.hpp"
#include "ngraph/op/util/op_annotations.hpp" #include "ngraph/op/util/op_annotations.hpp"
......
...@@ -22,8 +22,6 @@ ...@@ -22,8 +22,6 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "ngraph/log.hpp"
namespace ngraph namespace ngraph
{ {
enum class Placement enum class Placement
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/broadcast_distributed.hpp" #include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/distributed.hpp"
#include "ngraph/log.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
using namespace std; using namespace std;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/log.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/softmax.hpp" #include "ngraph/runtime/cpu/kernel/softmax.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp" #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/component_manager.hpp" #include "ngraph/component_manager.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/runtime/backend_manager.hpp" #include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp" #include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/cpu/cpu_builder_registry.hpp" #include "ngraph/runtime/cpu/cpu_builder_registry.hpp"
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "cpu_cse.hpp" #include "cpu_cse.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <string> #include <string>
#include "ngraph/log.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "deconv.hpp" #include "deconv.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <mkldnn.hpp> #include <mkldnn.hpp>
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "cpu_mat_fusion.hpp" #include "cpu_mat_fusion.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
......
...@@ -53,6 +53,7 @@ ...@@ -53,6 +53,7 @@
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <typeindex> #include <typeindex>
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/and.hpp" #include "ngraph/op/and.hpp"
#include "ngraph/op/equal.hpp" #include "ngraph/op/equal.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ngraph/env_util.hpp" #include "ngraph/env_util.hpp"
#include "ngraph/file_util.hpp" #include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops.hpp" #include "ngraph/ops.hpp"
#include "ngraph/provenance.hpp" #include "ngraph/provenance.hpp"
#include "ngraph/serializer.hpp" #include "ngraph/serializer.hpp"
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/log.hpp"
#include "ngraph/runtime/aligned_buffer.hpp" #include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/type/bfloat16.hpp" #include "ngraph/type/bfloat16.hpp"
#include "util/float_util.hpp" #include "util/float_util.hpp"
......
...@@ -37,3 +37,21 @@ TEST(type_prop, squeeze) ...@@ -37,3 +37,21 @@ TEST(type_prop, squeeze)
ASSERT_EQ(squeeze_default_axes->get_element_type(), element::f32); ASSERT_EQ(squeeze_default_axes->get_element_type(), element::f32);
ASSERT_EQ(squeeze_default_axes->get_shape(), (Shape{4, 4, 8})); ASSERT_EQ(squeeze_default_axes->get_shape(), (Shape{4, 4, 8}));
} }
TEST(type_prop, squeeze_dynamic)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(6));
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{0, 2});
auto squeeze = make_shared<op::Squeeze>(param, axes_node);
ASSERT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_TRUE(squeeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
axes_node = make_shared<ngraph::op::Constant>(element::u64, Shape{0}, vector<int64_t>{});
auto squeeze_default_axes = make_shared<op::Squeeze>(param, axes_node);
ASSERT_EQ(squeeze_default_axes->get_element_type(), element::f32);
EXPECT_TRUE(
squeeze_default_axes->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
...@@ -26,8 +26,19 @@ TEST(type_prop, unsqueeze) ...@@ -26,8 +26,19 @@ TEST(type_prop, unsqueeze)
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 1, 4, 1, 8}); auto param = make_shared<op::Parameter>(element::f32, Shape{4, 1, 4, 1, 8});
auto axes_node = auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2}); make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
auto squeeze = make_shared<op::Unsqueeze>(param, axes_node); auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
ASSERT_EQ(squeeze->get_element_type(), element::f32); ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
ASSERT_EQ(squeeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8})); ASSERT_EQ(unsqueeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8}));
}
TEST(type_prop, unsqueeze_dynamic)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(5));
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
EXPECT_TRUE(unsqueeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(7)));
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment