Commit 195b1aa8 authored by mbencer's avatar mbencer

Merge branch 'mbencer/BuilderSplitV1' of…

Merge branch 'mbencer/BuilderSplitV1' of https://github.com/NervanaSystems/ngraph into mbencer/BuilderSplitV1
parents 68698d51 ad3a67dc
......@@ -143,7 +143,7 @@ option(NGRAPH_PLAIDML_STATIC_LIB_ENABLE "Enable build PlaidML backend as a stati
option(NGRAPH_DYNAMIC_COMPONENTS_ENABLE "Enable dynamic loading of components" TRUE)
option(NGRAPH_NATIVE_ARCH_ENABLE "Enable build for native archtecture" TRUE)
option(NGRAPH_EXPORT_TARGETS_ENABLE "Enable exporting nGraph cmake export targets" TRUE)
option(NGRAPH_WARNINGS_AS_ERRORS "Make all nGraph compile-time warnings into errors" TRUE)
option(NGRAPH_WARNINGS_AS_ERRORS "Make all nGraph compile-time warnings into errors" FALSE)
if (NGRAPH_CPU_ENABLE)
option(NGRAPH_TBB_ENABLE "Control usage of TBB for CPU backend" TRUE)
endif()
......
......@@ -20,7 +20,7 @@
/doc/ @indie
/doc/examples/mnist_mlp/dist_* @wenzhe-nrv @indie
/doc/*/*/frameworks/tensorflow_connect.rst @shresthamalik @sayantan-nervana
/doc/*/*/backends/plaidml-ng-api/ @flaub @brianretford @dgkutnic
/doc/*/*/backends/plaidml-ng-api/ @diyessi
/doc/*/*/inspection/ @diyessi
/doc/examples/onnx/ @arogowie-intel @tsocha
/README.md @indie
......@@ -55,7 +55,7 @@ project/doc-contributor-README.rst @indie
/src/ngraph/runtime/dynamic/ @diyessi
/src/ngraph/runtime/gpu/ @csullivan @rkimballn1
/src/ngraph/runtime/interpreter/ @rkimballn1
/src/ngraph/runtime/plaidml/ @earhart @dgkutnic
/src/ngraph/runtime/plaidml/ @diyessi
/src/ngraph/runtime/reference/ @diyessi
/src/ngraph/runtime/reference/allreduce.*pp @wenzhe-nrv @diyessi
/src/ngraph/type/ @diyessi
......
......@@ -20,6 +20,7 @@
#include "mlir_subgraph_extraction.hpp"
#include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops.hpp"
......
......@@ -17,6 +17,7 @@
#include <sstream>
#include "ngraph/env_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace std;
......
......@@ -16,6 +16,7 @@
#include "matmul_integer.hpp"
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/log.hpp"
namespace ngraph
{
......
......@@ -16,6 +16,7 @@
#include "qlinear_matmul.hpp"
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/log.hpp"
namespace ngraph
{
......
......@@ -22,6 +22,7 @@
#include <mutex>
#include <thread>
#include "ngraph/distributed.hpp"
#include "ngraph/env_util.hpp"
#include "ngraph/log.hpp"
......
......@@ -30,8 +30,6 @@
#endif
#include <vector>
#include "ngraph/distributed.hpp"
namespace ngraph
{
class ConstString
......
......@@ -17,6 +17,7 @@
#pragma once
#include <memory>
#include "ngraph/distributed.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/validation_util.hpp"
......
......@@ -16,6 +16,7 @@
#include <sstream>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp"
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/binary_convolution.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/op/reshape.hpp"
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/broadcast.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -16,6 +16,7 @@
#include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/slice.hpp"
......
......@@ -32,20 +32,20 @@ op::Squeeze::Squeeze(const Output<Node>& data, const Output<Node>& axes)
constructor_validate_and_infer_types();
}
NodeVector op::Squeeze::decompose_op() const
void op::Squeeze::pre_validate_and_infer_types()
{
auto data = input_value(0);
auto axes_node = input_value(1).get_node_shared_ptr();
// Currently only support Constant node for axes.
NODE_VALIDATION_CHECK(this,
axes_node->is_constant(),
"doesn't support 'axes' input of other type than a Constant.");
if (data.get_partial_shape().is_dynamic() || !axes_node->is_constant())
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
return;
}
// Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape();
std::vector<uint64_t> axes_to_squeeze(data_shape.size());
......@@ -87,6 +87,18 @@ NodeVector op::Squeeze::decompose_op() const
}
}
set_output_type(0, get_input_element_type(0), output_data_shape);
}
NodeVector op::Squeeze::decompose_op() const
{
NODE_VALIDATION_CHECK(
this,
(get_output_partial_shape(0).is_static()),
"output shape was not calculated during pre_validate_and_infer_types. Can not decompose.");
auto data = input_value(0);
auto data_shape = data.get_shape();
auto output_data_shape = get_output_shape(0);
AxisVector input_order{get_default_order(data_shape.size())};
return {make_shared<op::Reshape>(data, input_order, output_data_shape)};
}
......
......@@ -38,6 +38,7 @@ namespace ngraph
Squeeze(const Output<Node>& data, const Output<Node>& axes);
virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -33,31 +33,35 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
}
void op::Unsqueeze::pre_validate_and_infer_types()
{
auto axes_node = input_value(1).get_node_shared_ptr();
// Currently only support Constant node for axes.
NODE_VALIDATION_CHECK(this,
axes_node->is_constant(),
"doesn't support 'axes' input of other type than a Constant.");
}
NodeVector op::Unsqueeze::decompose_op() const
{
auto data = input_value(0);
auto axes_node = input_value(1).get_node_shared_ptr();
if (data.get_partial_shape().rank().is_dynamic() || !axes_node->is_constant())
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
return;
}
// Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape();
NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory.");
NODE_VALIDATION_CHECK(this,
axes.size() == set<int64_t>(begin(axes), end(axes)).size(),
"'axes' input has a duplicate axis.");
if (data.get_partial_shape().is_dynamic())
{
set_output_type(0,
get_input_element_type(0),
PartialShape::dynamic(data.get_partial_shape().rank() + axes.size()));
return;
}
auto data_shape = data.get_shape();
sort(begin(axes), end(axes), less<int64_t>());
AxisVector input_order{ngraph::get_default_order(data_shape.size())};
......@@ -69,8 +73,20 @@ NodeVector op::Unsqueeze::decompose_op() const
data_shape.insert(next(begin(data_shape), axis), 1);
}
set_output_type(0, get_input_element_type(0), data_shape);
}
return {make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
NodeVector op::Unsqueeze::decompose_op() const
{
NODE_VALIDATION_CHECK(
this,
(get_output_partial_shape(0).is_static()),
"output shape was not calculated during pre_validate_and_infer_types. Can not decompose.");
auto data = input_value(0);
auto data_shape = data.get_shape();
auto output_shape = get_output_shape(0);
AxisVector input_order{ngraph::get_default_order(data_shape.size())};
return {make_shared<ngraph::op::Reshape>(data, input_order, output_shape)};
}
shared_ptr<Node> op::Unsqueeze::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -16,6 +16,7 @@
#include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/multiply.hpp"
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/binary_elementwise_logical.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -16,6 +16,7 @@
#include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/util/index_reduction.hpp"
using namespace std;
......
......@@ -25,6 +25,7 @@
#include "batch_fusion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
......
......@@ -16,6 +16,7 @@
#pragma once
#include "ngraph/log.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/util.hpp"
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/pass/pass_util.hpp"
#include "ngraph/log.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -17,6 +17,7 @@
#include "ngraph/pass/propagate_cacheability.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/util/op_annotations.hpp"
......
......@@ -22,8 +22,6 @@
#include <unordered_set>
#include <vector>
#include "ngraph/log.hpp"
namespace ngraph
{
enum class Placement
......
......@@ -15,6 +15,8 @@
//*****************************************************************************
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/distributed.hpp"
#include "ngraph/log.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
using namespace std;
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/softmax.hpp"
#include "ngraph/log.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/softmax.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
......
......@@ -22,6 +22,7 @@
#include "ngraph/component_manager.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/cpu/cpu_builder_registry.hpp"
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "cpu_cse.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......
......@@ -16,6 +16,7 @@
#include <string>
#include "ngraph/log.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......
......@@ -18,6 +18,7 @@
#include "deconv.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/util.hpp"
......
......@@ -24,6 +24,7 @@
#include <mkldnn.hpp>
#include "ngraph/descriptor/output.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
......
......@@ -24,6 +24,7 @@
#include "cpu_mat_fusion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
......
......@@ -53,6 +53,7 @@
#include "ngraph/descriptor/output.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/slice.hpp"
......
......@@ -17,6 +17,7 @@
#include <typeindex>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/greater.hpp"
......
......@@ -23,6 +23,7 @@
#include "ngraph/env_util.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/serializer.hpp"
......
......@@ -19,6 +19,7 @@
#include "gtest/gtest.h"
#include "ngraph/log.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/type/bfloat16.hpp"
#include "util/float_util.hpp"
......
......@@ -37,3 +37,21 @@ TEST(type_prop, squeeze)
ASSERT_EQ(squeeze_default_axes->get_element_type(), element::f32);
ASSERT_EQ(squeeze_default_axes->get_shape(), (Shape{4, 4, 8}));
}
TEST(type_prop, squeeze_dynamic)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(6));
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{0, 2});
auto squeeze = make_shared<op::Squeeze>(param, axes_node);
ASSERT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_TRUE(squeeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
axes_node = make_shared<ngraph::op::Constant>(element::u64, Shape{0}, vector<int64_t>{});
auto squeeze_default_axes = make_shared<op::Squeeze>(param, axes_node);
ASSERT_EQ(squeeze_default_axes->get_element_type(), element::f32);
EXPECT_TRUE(
squeeze_default_axes->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
......@@ -26,8 +26,19 @@ TEST(type_prop, unsqueeze)
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 1, 4, 1, 8});
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
auto squeeze = make_shared<op::Unsqueeze>(param, axes_node);
auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
ASSERT_EQ(squeeze->get_element_type(), element::f32);
ASSERT_EQ(squeeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8}));
ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
ASSERT_EQ(unsqueeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8}));
}
TEST(type_prop, unsqueeze_dynamic)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(5));
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
EXPECT_TRUE(unsqueeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(7)));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment