Commit c7b51d2d authored by Robert Kimball's avatar Robert Kimball

apply new .clang-format

parent 158de495
...@@ -27,7 +27,6 @@ namespace ngraph ...@@ -27,7 +27,6 @@ namespace ngraph
{ {
public: public:
size_t size() const { return m_size; } size_t size() const { return m_size; }
protected: protected:
size_t m_size; size_t m_size;
}; };
......
...@@ -29,7 +29,6 @@ namespace ngraph ...@@ -29,7 +29,6 @@ namespace ngraph
{ {
public: public:
BufferPos() {} BufferPos() {}
BufferPos(std::shared_ptr<Buffer> buffer, size_t offset, size_t size) BufferPos(std::shared_ptr<Buffer> buffer, size_t offset, size_t size)
: m_buffer(buffer) : m_buffer(buffer)
, m_offset(offset) , m_offset(offset)
......
...@@ -40,7 +40,6 @@ namespace ngraph ...@@ -40,7 +40,6 @@ namespace ngraph
virtual size_t get_index_offset(const std::vector<size_t>& indices) override; virtual size_t get_index_offset(const std::vector<size_t>& indices) override;
const Strides& get_strides() const { return m_strides; } const Strides& get_strides() const { return m_strides; }
protected: protected:
Strides m_strides; Strides m_strides;
size_t m_offset; size_t m_offset;
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/buffer_pos.hpp" #include "ngraph/descriptor/buffer_pos.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -41,7 +41,6 @@ namespace ngraph ...@@ -41,7 +41,6 @@ namespace ngraph
public: public:
virtual ~TensorViewLayout() {} virtual ~TensorViewLayout() {}
/// Extent of this view in buffer. /// Extent of this view in buffer.
/// ///
/// When we support non-linear buffers, this will need to be something other than size_t. /// When we support non-linear buffers, this will need to be something other than size_t.
...@@ -52,12 +51,14 @@ namespace ngraph ...@@ -52,12 +51,14 @@ namespace ngraph
/// With non-linear buffers, this will need to be something other than size_t. /// With non-linear buffers, this will need to be something other than size_t.
virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0; virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0;
const Shape& get_shape() const { return m_tensor_view.get_tensor_view_type()->get_shape(); } const Shape& get_shape() const
{
return m_tensor_view.get_tensor_view_type()->get_shape();
}
/// Where this view is located in the buffer. /// Where this view is located in the buffer.
const BufferPos& get_buffer_pos() const { return m_buffer_pos; } const BufferPos& get_buffer_pos() const { return m_buffer_pos; }
BufferPos& get_buffer_pos() { return m_buffer_pos; } BufferPos& get_buffer_pos() { return m_buffer_pos; }
protected: protected:
const ngraph::descriptor::TensorView& m_tensor_view; const ngraph::descriptor::TensorView& m_tensor_view;
BufferPos m_buffer_pos; BufferPos m_buffer_pos;
......
...@@ -57,7 +57,6 @@ namespace ngraph ...@@ -57,7 +57,6 @@ namespace ngraph
} }
const std::string& get_name() const { return m_name; } const std::string& get_name() const { return m_name; }
std::shared_ptr<const TensorViewType> get_tensor_view_type() const std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{ {
return m_tensor_view_type; return m_tensor_view_type;
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <initializer_list> #include <initializer_list>
#include <memory> #include <memory>
#include <vector>
#include <string> #include <string>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp" #include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
...@@ -41,10 +41,7 @@ namespace ngraph ...@@ -41,10 +41,7 @@ namespace ngraph
{ {
return m_parameters; return m_parameters;
} }
const std::shared_ptr<ValueType> get_result_type() const const std::shared_ptr<ValueType> get_result_type() const { return m_result_type; }
{
return m_result_type;
}
std::string get_name() const { return m_name; } std::string get_name() const { return m_name; }
protected: protected:
std::shared_ptr<Node> m_result; std::shared_ptr<Node> m_result;
......
...@@ -32,7 +32,9 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType> ...@@ -32,7 +32,9 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType>
} }
} }
Node::~Node() {} Node::~Node()
{
}
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type) void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
{ {
......
...@@ -65,9 +65,7 @@ namespace ngraph ...@@ -65,9 +65,7 @@ namespace ngraph
const Nodes& get_arguments() const { return m_arguments; } const Nodes& get_arguments() const { return m_arguments; }
void clear_arguments() { m_arguments.clear(); } void clear_arguments() { m_arguments.clear(); }
const std::multiset<Node*>& users() const { return m_users; } const std::multiset<Node*>& users() const { return m_users; }
virtual std::string get_node_id() const; virtual std::string get_node_id() const;
/// Return true if this has the same implementing class as node. This /// Return true if this has the same implementing class as node. This
...@@ -80,7 +78,6 @@ namespace ngraph ...@@ -80,7 +78,6 @@ namespace ngraph
std::shared_ptr<const ValueType> get_value_type() { return m_value_type; } std::shared_ptr<const ValueType> get_value_type() { return m_value_type; }
const std::shared_ptr<const ValueType> get_value_type() const { return m_value_type; } const std::shared_ptr<const ValueType> get_value_type() const { return m_value_type; }
void set_value_type(const element::Type& element_type, const Shape& shape) void set_value_type(const element::Type& element_type, const Shape& shape)
{ {
m_value_type = std::make_shared<TensorViewType>(element_type, shape); m_value_type = std::make_shared<TensorViewType>(element_type, shape);
...@@ -108,7 +105,6 @@ namespace ngraph ...@@ -108,7 +105,6 @@ namespace ngraph
const std::vector<descriptor::Input>& get_inputs() const { return m_inputs; } const std::vector<descriptor::Input>& get_inputs() const { return m_inputs; }
std::vector<descriptor::Output>& get_outputs() { return m_outputs; } std::vector<descriptor::Output>& get_outputs() { return m_outputs; }
const std::vector<descriptor::Output>& get_outputs() const { return m_outputs; } const std::vector<descriptor::Output>& get_outputs() const { return m_outputs; }
std::unordered_set<descriptor::Tensor*> liveness_live_list; std::unordered_set<descriptor::Tensor*> liveness_live_list;
std::unordered_set<descriptor::Tensor*> liveness_new_list; std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list; std::unordered_set<descriptor::Tensor*> liveness_free_list;
......
...@@ -19,8 +19,7 @@ using namespace ngraph; ...@@ -19,8 +19,7 @@ using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
const element::Type& BinaryElementwiseArithmetic::propagate_element_types( const element::Type& BinaryElementwiseArithmetic::propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg0_element_type, const element::Type& arg1_element_type) const
const element::Type& arg1_element_type) const
{ {
if (arg0_element_type != arg1_element_type) if (arg0_element_type != arg1_element_type)
{ {
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -41,11 +41,9 @@ void BinaryElementwiseBuiltin::propagate_types() ...@@ -41,11 +41,9 @@ void BinaryElementwiseBuiltin::propagate_types()
throw ngraph_error("Arguments must have the same tensor view shape"); throw ngraph_error("Arguments must have the same tensor view shape");
} }
const element::Type& result_element_type = const element::Type& result_element_type = propagate_element_types(
propagate_element_types(arg0_tensor_type->get_element_type(), arg0_tensor_type->get_element_type(), arg1_tensor_type->get_element_type());
arg1_tensor_type->get_element_type());
set_value_type_checked(make_shared<TensorViewType>(result_element_type, set_value_type_checked(
arg0_tensor_type->get_shape())); make_shared<TensorViewType>(result_element_type, arg0_tensor_type->get_shape()));
} }
...@@ -19,8 +19,7 @@ using namespace ngraph; ...@@ -19,8 +19,7 @@ using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
const element::Type& BinaryElementwiseComparison::propagate_element_types( const element::Type& BinaryElementwiseComparison::propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg0_element_type, const element::Type& arg1_element_type) const
const element::Type& arg1_element_type) const
{ {
if (arg0_element_type != arg1_element_type) if (arg0_element_type != arg1_element_type)
{ {
......
...@@ -19,7 +19,8 @@ using namespace ngraph::op; ...@@ -19,7 +19,8 @@ using namespace ngraph::op;
void Broadcast::propagate_types() void Broadcast::propagate_types()
{ {
if (m_arguments.size() != 1){ if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments."); throw ngraph_error("Wrong number of arguments.");
} }
...@@ -42,5 +43,6 @@ void Broadcast::propagate_types() ...@@ -42,5 +43,6 @@ void Broadcast::propagate_types()
{ {
throw ngraph_error("Broadcast arg, shape, and axes are incompatible"); throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
} }
set_value_type_checked(make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape)); set_value_type_checked(
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape));
} }
...@@ -40,7 +40,6 @@ namespace ngraph ...@@ -40,7 +40,6 @@ namespace ngraph
virtual void propagate_types() override; virtual void propagate_types() override;
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; } const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
protected: protected:
Shape m_shape; Shape m_shape;
AxisSet m_broadcast_axes; AxisSet m_broadcast_axes;
......
...@@ -47,7 +47,7 @@ void Concat::propagate_types() ...@@ -47,7 +47,7 @@ void Concat::propagate_types()
size_t concatenation_axis_length = arg0_shape.at(m_concatenation_axis); size_t concatenation_axis_length = arg0_shape.at(m_concatenation_axis);
auto& arg0_element_type = arg0_tensor_view_type->get_element_type(); auto& arg0_element_type = arg0_tensor_view_type->get_element_type();
for(auto i = 1; i < m_arguments.size(); i++) for (auto i = 1; i < m_arguments.size(); i++)
{ {
auto argi_type = m_arguments.at(i)->get_value_type(); auto argi_type = m_arguments.at(i)->get_value_type();
if (nullptr == argi_type) if (nullptr == argi_type)
...@@ -72,11 +72,12 @@ void Concat::propagate_types() ...@@ -72,11 +72,12 @@ void Concat::propagate_types()
throw ngraph_error("Argument element types do not match"); throw ngraph_error("Argument element types do not match");
} }
for(auto j = 0; j < argi_shape.size(); j++) for (auto j = 0; j < argi_shape.size(); j++)
{ {
if (j != m_concatenation_axis && arg0_shape.at(j) != argi_shape.at(j)) if (j != m_concatenation_axis && arg0_shape.at(j) != argi_shape.at(j))
{ {
throw ngraph_error("Arguments to concat do not have same dimension on a non-concatenation axis"); throw ngraph_error(
"Arguments to concat do not have same dimension on a non-concatenation axis");
} }
else if (j == m_concatenation_axis) else if (j == m_concatenation_axis)
{ {
......
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
/// ///
/// Example: n0 has shape {2,4,2}, and n1 has shape {2,5,2}. Then the output of /// Example: n0 has shape {2,4,2}, and n1 has shape {2,5,2}. Then the output of
/// Concat(Nodes{n0,n1},1) will have shape {2,9,2}. /// Concat(Nodes{n0,n1},1) will have shape {2,9,2}.
Concat(const Nodes& args,size_t concatenation_axis) Concat(const Nodes& args, size_t concatenation_axis)
: Builtin(args) : Builtin(args)
, m_concatenation_axis(concatenation_axis) , m_concatenation_axis(concatenation_axis)
{ {
...@@ -40,7 +40,6 @@ namespace ngraph ...@@ -40,7 +40,6 @@ namespace ngraph
virtual void propagate_types() override; virtual void propagate_types() override;
size_t get_concatenation_axis() const { return m_concatenation_axis; } size_t get_concatenation_axis() const { return m_concatenation_axis; }
protected: protected:
const size_t m_concatenation_axis; const size_t m_concatenation_axis;
}; };
......
...@@ -16,7 +16,10 @@ ...@@ -16,7 +16,10 @@
using namespace ngraph::op; using namespace ngraph::op;
void ScalarConstantBase::propagate_types() {} void ScalarConstantBase::propagate_types()
{
void TensorConstantBase::propagate_types() {} }
void TensorConstantBase::propagate_types()
{
}
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <sstream> #include <sstream>
#include "ngraph/types/element_type.hpp"
#include "ngraph/runtime/utils.hpp" #include "ngraph/runtime/utils.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -60,11 +60,7 @@ namespace ngraph ...@@ -60,11 +60,7 @@ namespace ngraph
return ss.str(); return ss.str();
} }
type get_value() const type get_value() const { return m_value; }
{
return m_value;
}
protected: protected:
typename T::type m_value; typename T::type m_value;
}; };
...@@ -113,7 +109,10 @@ namespace ngraph ...@@ -113,7 +109,10 @@ namespace ngraph
return ss.str(); return ss.str();
} }
typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> get_value() const { return m_value; } typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> get_value() const
{
return m_value;
}
protected: protected:
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value; std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value;
......
...@@ -56,22 +56,23 @@ void Dot::propagate_types() ...@@ -56,22 +56,23 @@ void Dot::propagate_types()
vector<size_t> result_shape; vector<size_t> result_shape;
result_shape.reserve(arg0_shape.size() + arg1_shape.size() - (is_scalar_mult ? 0 : 2)); result_shape.reserve(arg0_shape.size() + arg1_shape.size() - (is_scalar_mult ? 0 : 2));
for(auto i = 0; i < arg0_shape.size(); i++) for (auto i = 0; i < arg0_shape.size(); i++)
{ {
if(is_scalar_mult || i != arg0_reduction) if (is_scalar_mult || i != arg0_reduction)
{ {
result_shape.push_back(arg0_shape[i]); result_shape.push_back(arg0_shape[i]);
} }
} }
for(auto i = 0; i < arg1_shape.size(); i++) for (auto i = 0; i < arg1_shape.size(); i++)
{ {
if(is_scalar_mult || i != arg1_reduction) if (is_scalar_mult || i != arg1_reduction)
{ {
result_shape.push_back(arg1_shape[i]); result_shape.push_back(arg1_shape[i]);
} }
} }
auto result_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape); auto result_type =
make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
set_value_type_checked(result_type); set_value_type_checked(result_type);
} }
...@@ -39,7 +39,6 @@ namespace ngraph ...@@ -39,7 +39,6 @@ namespace ngraph
virtual void propagate_types() override; virtual void propagate_types() override;
std::shared_ptr<Function> get_function() const { return m_function; } std::shared_ptr<Function> get_function() const { return m_function; }
protected: protected:
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
}; };
......
...@@ -33,7 +33,8 @@ void GetTupleElement::propagate_types() ...@@ -33,7 +33,8 @@ void GetTupleElement::propagate_types()
throw ngraph_error("Argument must be a tuple view"); throw ngraph_error("Argument must be a tuple view");
} }
if (m_n >= arg0_tuple_type->get_element_types().size()){ if (m_n >= arg0_tuple_type->get_element_types().size())
{
throw ngraph_error("Indexing tuple beyond its size"); throw ngraph_error("Indexing tuple beyond its size");
} }
......
...@@ -33,9 +33,7 @@ namespace ngraph ...@@ -33,9 +33,7 @@ namespace ngraph
virtual void propagate_types() override; virtual void propagate_types() override;
virtual std::string description() const override { return "GetTupleElement"; } virtual std::string description() const override { return "GetTupleElement"; }
size_t get_n() const { return m_n; } size_t get_n() const { return m_n; }
protected: protected:
size_t m_n; size_t m_n;
}; };
......
...@@ -31,7 +31,6 @@ namespace ngraph ...@@ -31,7 +31,6 @@ namespace ngraph
{ {
public: public:
virtual std::string description() const override { return "Builtin"; } virtual std::string description() const override { return "Builtin"; }
protected: protected:
Builtin(const std::vector<std::shared_ptr<Node>>& args) Builtin(const std::vector<std::shared_ptr<Node>>& args)
: Node(args) : Node(args)
...@@ -73,8 +72,8 @@ namespace ngraph ...@@ -73,8 +72,8 @@ namespace ngraph
: Builtin(Nodes{arg}) : Builtin(Nodes{arg})
{ {
} }
virtual const element::Type& propagate_element_types( virtual const element::Type&
const element::Type& arg_element_type) const = 0; propagate_element_types(const element::Type& arg_element_type) const = 0;
public: public:
virtual void propagate_types() override; virtual void propagate_types() override;
...@@ -87,8 +86,8 @@ namespace ngraph ...@@ -87,8 +86,8 @@ namespace ngraph
: UnaryElementwiseBuiltin({arg}) : UnaryElementwiseBuiltin({arg})
{ {
} }
virtual const element::Type& propagate_element_types( virtual const element::Type&
const element::Type& arg_element_type) const final override; propagate_element_types(const element::Type& arg_element_type) const final override;
}; };
/// Op(X, Y)[I] = op(X[I], Y[I]) /// Op(X, Y)[I] = op(X[I], Y[I])
...@@ -100,8 +99,8 @@ namespace ngraph ...@@ -100,8 +99,8 @@ namespace ngraph
: Builtin(Nodes{arg0, arg1}) : Builtin(Nodes{arg0, arg1})
{ {
} }
virtual const element::Type& propagate_element_types( virtual const element::Type&
const element::Type& arg0_element_type, propagate_element_types(const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const = 0; const element::Type& arg1_element_type) const = 0;
public: public:
...@@ -111,34 +110,39 @@ namespace ngraph ...@@ -111,34 +110,39 @@ namespace ngraph
class BinaryElementwiseComparison : public BinaryElementwiseBuiltin class BinaryElementwiseComparison : public BinaryElementwiseBuiltin
{ {
public: public:
BinaryElementwiseComparison( BinaryElementwiseComparison(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
virtual std::string description() const override { return "BinaryElementwiseComparison"; } virtual std::string description() const override
{
return "BinaryElementwiseComparison";
}
//virtual void propagate_types() override; //virtual void propagate_types() override;
virtual const element::Type& propagate_element_types( virtual const element::Type&
const element::Type& arg0_element_type, propagate_element_types(const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const override; const element::Type& arg1_element_type) const override;
}; };
class BinaryElementwiseArithmetic : public BinaryElementwiseBuiltin class BinaryElementwiseArithmetic : public BinaryElementwiseBuiltin
{ {
public: public:
BinaryElementwiseArithmetic( BinaryElementwiseArithmetic(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
virtual std::string description() const override { return "BinaryElementwiseArithmetic"; } virtual std::string description() const override
{
return "BinaryElementwiseArithmetic";
}
//virtual void propagate_types() override; //virtual void propagate_types() override;
virtual const element::Type& propagate_element_types( virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const element::Type& arg1_element_type) const final override;
const final override;
}; };
} }
} }
...@@ -41,4 +41,6 @@ void Parameter::assign_function(Function* function, size_t index) ...@@ -41,4 +41,6 @@ void Parameter::assign_function(Function* function, size_t index)
m_index = index; m_index = index;
} }
void Parameter::propagate_types() {} void Parameter::propagate_types()
{
}
...@@ -37,7 +37,7 @@ namespace ngraph ...@@ -37,7 +37,7 @@ namespace ngraph
void assign_function(Function* function, size_t index); void assign_function(Function* function, size_t index);
public: public:
Parameter(const std::shared_ptr<ValueType>& value_type=nullptr); Parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
Parameter(const ngraph::element::Type& element_type, const Shape& shape); Parameter(const ngraph::element::Type& element_type, const Shape& shape);
std::string description() const override { return "Parameter"; } std::string description() const override { return "Parameter"; }
......
...@@ -30,7 +30,8 @@ void Reduce::propagate_types() ...@@ -30,7 +30,8 @@ void Reduce::propagate_types()
{ {
throw ngraph_error("Argument to reduce is missing type."); throw ngraph_error("Argument to reduce is missing type.");
} }
auto arg_reductee_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_reductee_type); auto arg_reductee_tensor_view_type =
dynamic_pointer_cast<const TensorViewType>(arg_reductee_type);
if (nullptr == arg_reductee_tensor_view_type) if (nullptr == arg_reductee_tensor_view_type)
{ {
throw ngraph_error("Argument to reduce is not a tensor view"); throw ngraph_error("Argument to reduce is not a tensor view");
...@@ -51,7 +52,8 @@ void Reduce::propagate_types() ...@@ -51,7 +52,8 @@ void Reduce::propagate_types()
throw ngraph_error("Argument for initial value is not a scalar"); throw ngraph_error("Argument for initial value is not a scalar");
} }
if (arg_init_tensor_view_type->get_element_type() != arg_reductee_tensor_view_type->get_element_type()) if (arg_init_tensor_view_type->get_element_type() !=
arg_reductee_tensor_view_type->get_element_type())
{ {
throw ngraph_error("Element types for reductee and initial values do not match"); throw ngraph_error("Element types for reductee and initial values do not match");
} }
...@@ -99,5 +101,6 @@ void Reduce::propagate_types() ...@@ -99,5 +101,6 @@ void Reduce::propagate_types()
throw ngraph_error("Return type from reduction function does not match expected"); throw ngraph_error("Return type from reduction function does not match expected");
} }
set_value_type_checked(make_shared<TensorViewType>(arg_reductee_tensor_view_type->get_element_type(), result_shape)); set_value_type_checked(make_shared<TensorViewType>(
arg_reductee_tensor_view_type->get_element_type(), result_shape));
} }
...@@ -31,7 +31,7 @@ namespace ngraph ...@@ -31,7 +31,7 @@ namespace ngraph
const std::shared_ptr<Node>& arg_init, const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function, const std::shared_ptr<Function>& reduction_function,
const AxisSet& reduction_axes) const AxisSet& reduction_axes)
: Builtin({arg_reductee,arg_init}) : Builtin({arg_reductee, arg_init})
, m_reduction_function(reduction_function) , m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes) , m_reduction_axes(reduction_axes)
{ {
...@@ -40,9 +40,11 @@ namespace ngraph ...@@ -40,9 +40,11 @@ namespace ngraph
virtual std::string description() const override { return "Reduce"; } virtual std::string description() const override { return "Reduce"; }
virtual void propagate_types() override; virtual void propagate_types() override;
std::shared_ptr<Function> get_reduction_function() const { return m_reduction_function; } std::shared_ptr<Function> get_reduction_function() const
{
return m_reduction_function;
}
const AxisSet& get_reduction_axes() const { return m_reduction_axes; } const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
protected: protected:
std::shared_ptr<Function> m_reduction_function; std::shared_ptr<Function> m_reduction_function;
AxisSet m_reduction_axes; AxisSet m_reduction_axes;
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -42,8 +42,8 @@ void Select::propagate_types() ...@@ -42,8 +42,8 @@ void Select::propagate_types()
{ {
throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type"); throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type");
} }
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape() if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape() ||
|| arg0_tensor_type->get_shape() != arg2_tensor_type->get_shape()) arg0_tensor_type->get_shape() != arg2_tensor_type->get_shape())
{ {
throw ngraph_error("Arguments must have the same tensor view shape"); throw ngraph_error("Arguments must have the same tensor view shape");
} }
...@@ -54,4 +54,3 @@ void Select::propagate_types() ...@@ -54,4 +54,3 @@ void Select::propagate_types()
set_value_type_checked(arg1_tensor_type); set_value_type_checked(arg1_tensor_type);
} }
...@@ -20,8 +20,8 @@ using namespace std; ...@@ -20,8 +20,8 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
const element::Type& UnaryElementwiseArithmetic::propagate_element_types( const element::Type&
const element::Type& arg_element_type) const UnaryElementwiseArithmetic::propagate_element_types(const element::Type& arg_element_type) const
{ {
if (arg_element_type == element::Bool::element_type()) if (arg_element_type == element::Bool::element_type())
{ {
......
...@@ -37,6 +37,6 @@ void UnaryElementwiseBuiltin::propagate_types() ...@@ -37,6 +37,6 @@ void UnaryElementwiseBuiltin::propagate_types()
const element::Type& result_element_type = const element::Type& result_element_type =
propagate_element_types(arg_tensor_type->get_element_type()); propagate_element_types(arg_tensor_type->get_element_type());
set_value_type_checked(make_shared<TensorViewType>(result_element_type, set_value_type_checked(
arg_tensor_type->get_shape())); make_shared<TensorViewType>(result_element_type, arg_tensor_type->get_shape()));
} }
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <fstream> #include <fstream>
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -51,7 +51,6 @@ bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes) ...@@ -51,7 +51,6 @@ bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes)
out << join(outputs); out << join(outputs);
out << "\n"; out << "\n";
for (const Tensor* tensor : node->liveness_live_list) for (const Tensor* tensor : node->liveness_live_list)
{ {
out << " L " << tensor->get_name() << "\n"; out << " L " << tensor->get_name() << "\n";
......
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
#include <sstream> #include <sstream>
#include <unordered_set> #include <unordered_set>
#include "ngraph/log.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/log.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -31,7 +31,7 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops) ...@@ -31,7 +31,7 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops)
{ {
unordered_set<Tensor*> currently_live; unordered_set<Tensor*> currently_live;
for(auto it=ops.rbegin(); it!=ops.rend(); it++) for (auto it = ops.rbegin(); it != ops.rend(); it++)
{ {
Node* node = *it; Node* node = *it;
node->liveness_live_list.clear(); node->liveness_live_list.clear();
...@@ -143,11 +143,8 @@ void pass::Liveness::check_dependencies( ...@@ -143,11 +143,8 @@ void pass::Liveness::check_dependencies(
bool pass::Liveness::is_temporary(const Tensor& tensor) bool pass::Liveness::is_temporary(const Tensor& tensor)
{ {
return return tensor.is_persistent() == false && tensor.is_input() == false &&
tensor.is_persistent() == false tensor.is_output() == false;
&& tensor.is_input() == false
&& tensor.is_output() == false
;
// && tensor.is_constant() == false // && tensor.is_constant() == false
// && tensor.is_compile_only() == false; // && tensor.is_compile_only() == false;
} }
...@@ -170,4 +167,3 @@ void pass::Liveness::validate_liveness(const list<Node*>& ops) ...@@ -170,4 +167,3 @@ void pass::Liveness::validate_liveness(const list<Node*>& ops)
dead_tensors.insert(node->liveness_free_list.begin(), node->liveness_free_list.end()); dead_tensors.insert(node->liveness_free_list.begin(), node->liveness_free_list.end());
} }
} }
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#pragma once #pragma once
#include "ngraph/pass/call_pass.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/pass/call_pass.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/function.hpp" #include "ngraph/pass/manager.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include <vector>
#include <memory>
#include <list> #include <list>
#include <memory>
#include <vector>
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/call_pass.hpp"
#include "ngraph/pass/tree_pass.hpp" #include "ngraph/pass/tree_pass.hpp"
...@@ -59,7 +59,7 @@ public: ...@@ -59,7 +59,7 @@ public:
void initialize_default_passes(); void initialize_default_passes();
template<typename T, class... Args> template <typename T, class... Args>
void register_pass(Args... args) void register_pass(Args... args)
{ {
static_assert(std::is_base_of<pass::Base, T>::value, "pass not derived from pass base"); static_assert(std::is_base_of<pass::Base, T>::value, "pass not derived from pass base");
......
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
#include <exception> #include <exception>
#include <sstream> #include <sstream>
#include "ngraph/log.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "ngraph/pass/memory_layout.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -69,7 +69,6 @@ pass::MemoryManager::node::node(size_t size, block_state state) ...@@ -69,7 +69,6 @@ pass::MemoryManager::node::node(size_t size, block_state state)
: m_size{size} : m_size{size}
, m_state{state} , m_state{state}
{ {
} }
pass::MemoryManager::MemoryManager(size_t alignment) pass::MemoryManager::MemoryManager(size_t alignment)
...@@ -84,14 +83,10 @@ pass::MemoryManager::MemoryManager(size_t alignment) ...@@ -84,14 +83,10 @@ pass::MemoryManager::MemoryManager(size_t alignment)
size_t pass::MemoryManager::allocate(size_t size) size_t pass::MemoryManager::allocate(size_t size)
{ {
size_t rc; size_t rc;
switch(m_scheme) switch (m_scheme)
{ {
case allocation_scheme::FIRST_FIT: case allocation_scheme::FIRST_FIT: rc = first_fit(size); break;
rc = first_fit(size); case allocation_scheme::BEST_FIT: rc = best_fit(size); break;
break;
case allocation_scheme::BEST_FIT:
rc = best_fit(size);
break;
} }
return rc; return rc;
} }
...@@ -103,7 +98,7 @@ size_t pass::MemoryManager::best_fit(size_t size) ...@@ -103,7 +98,7 @@ size_t pass::MemoryManager::best_fit(size_t size)
size_t min_delta = numeric_limits<size_t>::max(); size_t min_delta = numeric_limits<size_t>::max();
auto best_fit = m_node_list.end(); auto best_fit = m_node_list.end();
size_t best_offset = offset; size_t best_offset = offset;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it) for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
{ {
if (it->m_state == block_state::FREE && it->m_size >= size) if (it->m_state == block_state::FREE && it->m_size >= size)
{ {
...@@ -143,7 +138,7 @@ size_t pass::MemoryManager::first_fit(size_t size) ...@@ -143,7 +138,7 @@ size_t pass::MemoryManager::first_fit(size_t size)
size = align(size, m_alignment); size = align(size, m_alignment);
size_t offset = 0; size_t offset = 0;
bool found = false; bool found = false;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it) for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
{ {
if (it->m_state == block_state::FREE && it->m_size >= size) if (it->m_state == block_state::FREE && it->m_size >= size)
{ {
...@@ -176,7 +171,7 @@ void pass::MemoryManager::free(size_t offset) ...@@ -176,7 +171,7 @@ void pass::MemoryManager::free(size_t offset)
{ {
size_t search_offset = 0; size_t search_offset = 0;
bool found = false; bool found = false;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it) for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
{ {
if (offset == search_offset) if (offset == search_offset)
{ {
......
...@@ -62,12 +62,11 @@ public: ...@@ -62,12 +62,11 @@ public:
node(size_t size, block_state state); node(size_t size, block_state state);
bool is_free() const { return m_state == block_state::FREE; } bool is_free() const { return m_state == block_state::FREE; }
size_t m_size; size_t m_size;
block_state m_state; block_state m_state;
}; };
MemoryManager(size_t alignment=1); MemoryManager(size_t alignment = 1);
// memory_manager& alignment(size_t a); // memory_manager& alignment(size_t a);
size_t allocate(size_t size); size_t allocate(size_t size);
...@@ -81,11 +80,8 @@ public: ...@@ -81,11 +80,8 @@ public:
std::list<node>::iterator end() { return m_node_list.end(); } std::list<node>::iterator end() { return m_node_list.end(); }
std::list<node>::const_iterator begin() const { return m_node_list.cbegin(); } std::list<node>::const_iterator begin() const { return m_node_list.cbegin(); }
std::list<node>::const_iterator end() const { return m_node_list.cend(); } std::list<node>::const_iterator end() const { return m_node_list.cend(); }
const std::list<node>& get_node_list() const { return m_node_list; } const std::list<node>& get_node_list() const { return m_node_list; }
size_t max_allocated() const { return m_max_allocated; } size_t max_allocated() const { return m_max_allocated; }
private: private:
size_t first_fit(size_t size); size_t first_fit(size_t size);
size_t best_fit(size_t size); size_t best_fit(size_t size);
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <algorithm>
#include <fstream> #include <fstream>
#include <unordered_set>
#include <unordered_map> #include <unordered_map>
#include <algorithm> #include <unordered_set>
#include "memory_visualize.hpp" #include "memory_visualize.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
...@@ -154,8 +154,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>& ...@@ -154,8 +154,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
} }
i++; i++;
} }
sort(tensor_set.begin(), tensor_set.end(), [](const Tensor* t1, const Tensor* t2) sort(tensor_set.begin(), tensor_set.end(), [](const Tensor* t1, const Tensor* t2) {
{
return t1->size() < t2->size(); return t1->size() < t2->size();
}); });
for (const Tensor* tensor : tensor_set) for (const Tensor* tensor : tensor_set)
...@@ -206,12 +205,16 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod ...@@ -206,12 +205,16 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
y += line_spacing; y += line_spacing;
size_t x1 = offset; size_t x1 = offset;
size_t x2 = ((usage / memory_footprint) * scale) + offset; size_t x2 = ((usage / memory_footprint) * scale) + offset;
file << "<text x=\"" << 0 << "\" y=\"" << y + text_offset << "\" fill=\"" << "black" << "\">" << node->get_node_id() << "</text>\n"; file << "<text x=\"" << 0 << "\" y=\"" << y + text_offset << "\" fill=\""
file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y << "\""; << "black"
<< "\">" << node->get_node_id() << "</text>\n";
file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y
<< "\"";
file << " style=\"stroke:forestgreen;stroke-width:" << stroke_width << "\" />\n"; file << " style=\"stroke:forestgreen;stroke-width:" << stroke_width << "\" />\n";
x1 = x2; x1 = x2;
x2 = ((footprint / memory_footprint) * scale) + offset; x2 = ((footprint / memory_footprint) * scale) + offset;
file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y << "\""; file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y
<< "\"";
file << " style=\"stroke:firebrick;stroke-width:" << stroke_width << "\" />\n"; file << " style=\"stroke:firebrick;stroke-width:" << stroke_width << "\" />\n";
} }
file << "</svg>\n"; file << "</svg>\n";
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include <iostream>
#include <limits> #include <limits>
#include <list> #include <list>
#include <iostream>
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/call_pass.hpp"
......
...@@ -27,6 +27,7 @@ namespace ngraph ...@@ -27,6 +27,7 @@ namespace ngraph
class ngraph::pass::Base class ngraph::pass::Base
{ {
friend class Manager; friend class Manager;
public: public:
protected: protected:
ManagerState& get_state(); ManagerState& get_state();
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <fstream> #include <fstream>
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -24,8 +24,7 @@ using namespace std; ...@@ -24,8 +24,7 @@ using namespace std;
bool pass::VisualizeTree::run_on_tree(std::shared_ptr<Node> base_node) bool pass::VisualizeTree::run_on_tree(std::shared_ptr<Node> base_node)
{ {
// map<size_t, list<node_ptr>> dependent_nodes; // map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(base_node, [&](Node* node) traverse_nodes(base_node, [&](Node* node) {
{
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
m_ss << add_attributes(arg.get()); m_ss << add_attributes(arg.get());
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include <set>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <set>
#include "ngraph/pass/tree_pass.hpp" #include "ngraph/pass/tree_pass.hpp"
......
...@@ -48,23 +48,19 @@ namespace ngraph ...@@ -48,23 +48,19 @@ namespace ngraph
void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs); void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs);
void set_return() { m_return = true; } void set_return() { m_return = true; }
std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; } std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; }
template <typename ET> template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i) ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i)
{ {
return m_tensor_views[i]->get_parameterized_tensor_view<ET>(); return m_tensor_views[i]->get_parameterized_tensor_view<ET>();
} }
template<typename ET> template <typename ET>
typename ET::type* get_tensor_view_data(size_t i) typename ET::type* get_tensor_view_data(size_t i)
{ {
return &get_parameterized_tensor_view<ET>(i)->get_vector()[0]; return &get_parameterized_tensor_view<ET>(i)->get_vector()[0];
} }
protected: protected:
size_t m_n_inputs; size_t m_n_inputs;
size_t m_n_outputs; size_t m_n_outputs;
......
...@@ -38,7 +38,8 @@ namespace ngraph ...@@ -38,7 +38,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = Eigen::abs(EigenArray1d<ET>(call_frame, m_arg)); EigenArray1d<ET>(call_frame, m_out) =
Eigen::abs(EigenArray1d<ET>(call_frame, m_arg));
} }
protected: protected:
......
...@@ -29,8 +29,7 @@ namespace ngraph ...@@ -29,8 +29,7 @@ namespace ngraph
class BroadcastScalarInstruction : public Instruction class BroadcastScalarInstruction : public Instruction
{ {
public: public:
BroadcastScalarInstruction(const TensorViewInfo& arg, BroadcastScalarInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -42,7 +41,7 @@ namespace ngraph ...@@ -42,7 +41,7 @@ namespace ngraph
// pull it out as a vector. This works because of the way // pull it out as a vector. This works because of the way
// fmt::V computes sizes---it lumps together any higher // fmt::V computes sizes---it lumps together any higher
// dimensions---while fmt::M ignores them. // dimensions---while fmt::M ignores them.
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg)(0,0); EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg)(0, 0);
} }
protected: protected:
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/eigen/utils.hpp" #include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
...@@ -29,7 +29,9 @@ namespace ngraph ...@@ -29,7 +29,9 @@ namespace ngraph
class CallInstruction : public Instruction class CallInstruction : public Instruction
{ {
public: public:
CallInstruction(std::shared_ptr<ExternalFunction> ef,std::vector<TensorViewInfo> in, std::vector<TensorViewInfo> out) CallInstruction(std::shared_ptr<ExternalFunction> ef,
std::vector<TensorViewInfo> in,
std::vector<TensorViewInfo> out)
: m_external_function(ef) : m_external_function(ef)
, m_in(in) , m_in(in)
, m_out(out) , m_out(out)
...@@ -51,7 +53,7 @@ namespace ngraph ...@@ -51,7 +53,7 @@ namespace ngraph
{ {
outputs.push_back(call_frame.get_tensor_view(out.get_index())); outputs.push_back(call_frame.get_tensor_view(out.get_index()));
} }
(*cf)(inputs,outputs); (*cf)(inputs, outputs);
} }
protected: protected:
......
...@@ -46,8 +46,10 @@ namespace ngraph ...@@ -46,8 +46,10 @@ namespace ngraph
{ {
EigenVector<ET> out(call_frame, m_out); EigenVector<ET> out(call_frame, m_out);
size_t concat_pos = 0; size_t concat_pos = 0;
for (size_t i = 0; i < m_args.size(); i++){ for (size_t i = 0; i < m_args.size(); i++)
out.segment(concat_pos, m_sizes[i]) << EigenVector<ET>(call_frame, m_args.at(i)); {
out.segment(concat_pos, m_sizes[i])
<< EigenVector<ET>(call_frame, m_args.at(i));
concat_pos += m_sizes[i]; concat_pos += m_sizes[i];
} }
} }
......
...@@ -30,7 +30,8 @@ namespace ngraph ...@@ -30,7 +30,8 @@ namespace ngraph
class ConstantInstruction : public Instruction class ConstantInstruction : public Instruction
{ {
public: public:
ConstantInstruction(const std::vector<typename ET::type> value, const TensorViewInfo& out) ConstantInstruction(const std::vector<typename ET::type> value,
const TensorViewInfo& out)
: m_value(value) : m_value(value)
, m_out(out) , m_out(out)
{ {
...@@ -38,7 +39,8 @@ namespace ngraph ...@@ -38,7 +39,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
call_frame.get_parameterized_tensor_view<ET>(m_out.get_index())->get_vector() = m_value; call_frame.get_parameterized_tensor_view<ET>(m_out.get_index())->get_vector() =
m_value;
} }
protected: protected:
......
...@@ -40,8 +40,9 @@ namespace ngraph ...@@ -40,8 +40,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) << EigenArray1d<ET>(call_frame, m_out)
EigenVector<ET>(call_frame, m_arg0).dot(EigenVector<ET>(call_frame, m_arg1)); << EigenVector<ET>(call_frame, m_arg0)
.dot(EigenVector<ET>(call_frame, m_arg1));
} }
protected: protected:
......
...@@ -37,7 +37,8 @@ namespace ngraph ...@@ -37,7 +37,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = Eigen::log(EigenArray1d<ET, fmt::V>(call_frame, m_arg)); EigenArray1d<ET, fmt::V>(call_frame, m_out) =
Eigen::log(EigenArray1d<ET, fmt::V>(call_frame, m_arg));
} }
protected: protected:
......
...@@ -27,7 +27,6 @@ namespace ngraph ...@@ -27,7 +27,6 @@ namespace ngraph
{ {
public: public:
ReturnInstruction() {} ReturnInstruction() {}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
call_frame.set_return(); call_frame.set_return();
......
...@@ -45,8 +45,8 @@ namespace ngraph ...@@ -45,8 +45,8 @@ namespace ngraph
// fmt::V computes sizes---it lumps together any higher // fmt::V computes sizes---it lumps together any higher
// dimensions---while fmt::M ignores them. // dimensions---while fmt::M ignores them.
EigenVector<ET>(call_frame, m_out) = EigenVector<ET>(call_frame, m_out) =
call_frame.get_tensor_view_data<ET>(m_arg0.get_index())[0] call_frame.get_tensor_view_data<ET>(m_arg0.get_index())[0] *
* EigenVector<ET>(call_frame, m_arg1); EigenVector<ET>(call_frame, m_arg1);
} }
protected: protected:
......
...@@ -40,7 +40,8 @@ namespace ngraph ...@@ -40,7 +40,8 @@ namespace ngraph
using EigenArrayBase = Eigen::Map<DynamicArray<ET>, 0, DynamicStrides>; using EigenArrayBase = Eigen::Map<DynamicArray<ET>, 0, DynamicStrides>;
template <typename ET> template <typename ET>
using DynamicMatrix = Eigen::Matrix<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; using DynamicMatrix =
Eigen::Matrix<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
template <typename ET> template <typename ET>
using EigenMatrixBase = Eigen::Map<DynamicMatrix<ET>, 0, DynamicStrides>; using EigenMatrixBase = Eigen::Map<DynamicMatrix<ET>, 0, DynamicStrides>;
......
...@@ -97,7 +97,8 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -97,7 +97,8 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
const std::vector<TensorViewInfo>& out) const std::vector<TensorViewInfo>& out)
#define REGISTER_INSTRUCTION(op_class, instr_class, ...) \ #define REGISTER_INSTRUCTION(op_class, instr_class, ...) \
REGISTER_TO_OP_MAP(op_class) { \ REGISTER_TO_OP_MAP(op_class) \
{ \
ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \ ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \
} }
...@@ -146,8 +147,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -146,8 +147,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
auto broadcast = static_cast<const op::Broadcast*>(n); auto broadcast = static_cast<const op::Broadcast*>(n);
auto arg_tensor_type = auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>(
dynamic_pointer_cast<const TensorViewType>(n->get_arguments().at(0)->get_value_type()); n->get_arguments().at(0)->get_value_type());
assert(nullptr != arg_tensor_type); assert(nullptr != arg_tensor_type);
auto result_tensor_type = auto result_tensor_type =
...@@ -175,18 +176,22 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -175,18 +176,22 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
if (broadcast->get_broadcast_axes() == AxisSet{1}) if (broadcast->get_broadcast_axes() == AxisSet{1})
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::BroadcastVectorColwiseInstruction<element::Float32>>( make_shared<
runtime::eigen::BroadcastVectorColwiseInstruction<element::Float32>>(
in[0], out[0])); in[0], out[0]));
} }
else if (broadcast->get_broadcast_axes() == AxisSet{0}) else if (broadcast->get_broadcast_axes() == AxisSet{0})
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::BroadcastVectorRowwiseInstruction<element::Float32>>( make_shared<
runtime::eigen::BroadcastVectorRowwiseInstruction<element::Float32>>(
in[0], out[0])); in[0], out[0]));
} }
else else
{ {
throw ngraph_error("Internal error: axis set for vector-matrix broadcast is neither {0} or {1}"); throw ngraph_error(
"Internal error: axis set for vector-matrix broadcast is neither {0} or "
"{1}");
} }
} }
else else
...@@ -206,8 +211,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -206,8 +211,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
if (result_shape.size() == 1) if (result_shape.size() == 1)
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>( make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>(in,
in, out[0])); out[0]));
} }
else if (result_shape.size() == 2) else if (result_shape.size() == 2)
{ {
...@@ -286,7 +291,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -286,7 +291,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
}; };
// Parameter is a "runtime no-op" because the output tensor has already been filled. // Parameter is a "runtime no-op" because the output tensor has already been filled.
REGISTER_TO_OP_MAP(op::Parameter) {}; REGISTER_TO_OP_MAP(op::Parameter){};
// GetTupleElement will be spliced out, with the users of out redirected to in's source, but, for now, we need to copy. // GetTupleElement will be spliced out, with the users of out redirected to in's source, but, for now, we need to copy.
REGISTER_TO_OP_MAP(op::GetTupleElement) REGISTER_TO_OP_MAP(op::GetTupleElement)
...@@ -322,20 +327,16 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -322,20 +327,16 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
} }
catch (const std::out_of_range) catch (const std::out_of_range)
{ {
external = make_shared<ngraph::runtime::ExternalFunction>( external =
function_call->get_function()); make_shared<ngraph::runtime::ExternalFunction>(function_call->get_function());
function_map.insert({function,external}); function_map.insert({function, external});
} }
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::CallInstruction>(external,in,out)); make_shared<runtime::eigen::CallInstruction>(external, in, out));
};
REGISTER_TO_OP_MAP(op::Reduce)
{
throw ngraph_error("op::Reduce not implemented yet");
}; };
REGISTER_TO_OP_MAP(op::Reduce) { throw ngraph_error("op::Reduce not implemented yet"); };
initialized = true; initialized = true;
} }
return op_map; return op_map;
......
...@@ -28,7 +28,8 @@ namespace ngraph ...@@ -28,7 +28,8 @@ namespace ngraph
{ {
class ExternalFunction class ExternalFunction
{ {
using FunctionMap = std::unordered_map<std::shared_ptr<Function>,std::shared_ptr<ExternalFunction>>; using FunctionMap =
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<ExternalFunction>>;
using OpFunction = std::function<void(const ngraph::Node*, using OpFunction = std::function<void(const ngraph::Node*,
ExternalFunction*, ExternalFunction*,
...@@ -50,7 +51,6 @@ namespace ngraph ...@@ -50,7 +51,6 @@ namespace ngraph
// Release original function's resources // Release original function's resources
void release_function() { m_function = nullptr; } void release_function() { m_function = nullptr; }
protected: protected:
void compile(); void compile();
void compile(FunctionMap& function_map); void compile(FunctionMap& function_map);
......
...@@ -61,7 +61,6 @@ namespace ngraph ...@@ -61,7 +61,6 @@ namespace ngraph
// For getting the data out // For getting the data out
storage_type& get_vector() { return m_vector; } storage_type& get_vector() { return m_vector; }
protected: protected:
storage_type m_vector; storage_type m_vector;
}; };
......
...@@ -39,9 +39,7 @@ namespace ngraph ...@@ -39,9 +39,7 @@ namespace ngraph
public: public:
TensorView() {} TensorView() {}
virtual ~TensorView() {} virtual ~TensorView() {}
template <typename ET> template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor_view() ParameterizedTensorView<ET>* get_parameterized_tensor_view()
{ {
...@@ -65,7 +63,6 @@ namespace ngraph ...@@ -65,7 +63,6 @@ namespace ngraph
} }
const Shape& get_shape() { return m_descriptor->get_tensor_view_type()->get_shape(); } const Shape& get_shape() { return m_descriptor->get_tensor_view_type()->get_shape(); }
protected: protected:
std::shared_ptr<ngraph::descriptor::TensorView> m_descriptor; std::shared_ptr<ngraph::descriptor::TensorView> m_descriptor;
}; };
......
...@@ -34,7 +34,6 @@ namespace ngraph ...@@ -34,7 +34,6 @@ namespace ngraph
} }
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
std::shared_ptr<ngraph::descriptor::layout::TensorViewLayout> std::shared_ptr<ngraph::descriptor::layout::TensorViewLayout>
get_tensor_view_layout() const get_tensor_view_layout() const
{ {
......
...@@ -40,8 +40,7 @@ namespace ngraph ...@@ -40,8 +40,7 @@ namespace ngraph
return m_descriptor; return m_descriptor;
} }
virtual void virtual void collect_tensor_views(std::vector<std::shared_ptr<TensorView>>& views,
collect_tensor_views(std::vector<std::shared_ptr<TensorView>>& views,
const std::shared_ptr<Value>& value) const override; const std::shared_ptr<Value>& value) const override;
protected: protected:
......
...@@ -30,7 +30,6 @@ namespace ngraph ...@@ -30,7 +30,6 @@ namespace ngraph
{ {
public: public:
virtual ~Value() {} virtual ~Value() {}
/// @brief The compile-time descriptor for this value. /// @brief The compile-time descriptor for this value.
virtual std::shared_ptr<ngraph::descriptor::Value> get_descriptor() const = 0; virtual std::shared_ptr<ngraph::descriptor::Value> get_descriptor() const = 0;
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <cmath> #include <cmath>
#include <iostream> #include <iostream>
#include "ngraph/types/element_type.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph; using namespace ngraph;
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -39,7 +39,8 @@ bool TensorViewType::operator==(const ValueType& that) const ...@@ -39,7 +39,8 @@ bool TensorViewType::operator==(const ValueType& that) const
return true; return true;
} }
void TensorViewType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const void TensorViewType::collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const
{ {
views.push_back(shared_from_this()); views.push_back(shared_from_this());
} }
...@@ -54,9 +55,10 @@ bool TupleType::operator==(const ValueType& that) const ...@@ -54,9 +55,10 @@ bool TupleType::operator==(const ValueType& that) const
return that_tvt->get_element_types() == get_element_types(); return that_tvt->get_element_types() == get_element_types();
} }
void TupleType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const void TupleType::collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const
{ {
for(auto elt : m_element_types) for (auto elt : m_element_types)
{ {
elt->collect_tensor_views(views); elt->collect_tensor_views(views);
} }
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/types/element_type.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,12 +35,10 @@ namespace ngraph ...@@ -35,12 +35,10 @@ namespace ngraph
protected: protected:
ValueType() {} ValueType() {}
public: public:
virtual ~ValueType() {} virtual ~ValueType() {}
virtual bool operator==(const ValueType& that) const = 0; virtual bool operator==(const ValueType& that) const = 0;
bool operator!=(const ValueType& that) const { return !(*this == that); } bool operator!=(const ValueType& that) const { return !(*this == that); }
/// Add tensor views in depth-first order. /// Add tensor views in depth-first order.
virtual void collect_tensor_views( virtual void collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0; std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0;
...@@ -62,7 +60,6 @@ namespace ngraph ...@@ -62,7 +60,6 @@ namespace ngraph
const element::Type& get_element_type() const { return m_element_type; } const element::Type& get_element_type() const { return m_element_type; }
const Shape& get_shape() const { return m_shape; } const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType& that) const override; virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views( virtual void collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const override; std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
...@@ -80,7 +77,6 @@ namespace ngraph ...@@ -80,7 +77,6 @@ namespace ngraph
public: public:
/// Construct empty tuple and add value types later. /// Construct empty tuple and add value types later.
TupleType() {} TupleType() {}
/// @param element_types A vector of types for the tuple elements /// @param element_types A vector of types for the tuple elements
TupleType(const std::vector<std::shared_ptr<const ValueType>>& element_types) TupleType(const std::vector<std::shared_ptr<const ValueType>>& element_types)
: m_element_types(element_types) : m_element_types(element_types)
...@@ -91,7 +87,10 @@ namespace ngraph ...@@ -91,7 +87,10 @@ namespace ngraph
{ {
return m_element_types; return m_element_types;
} }
std::vector<std::shared_ptr<const ValueType>> set_element_types() { return m_element_types; } std::vector<std::shared_ptr<const ValueType>> set_element_types()
{
return m_element_types;
}
virtual bool operator==(const ValueType& that) const override; virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views( virtual void collect_tensor_views(
......
...@@ -12,15 +12,15 @@ ...@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <iomanip>
#include <map>
#include <deque> #include <deque>
#include <forward_list> #include <forward_list>
#include <iomanip>
#include <map>
#include <unordered_set> #include <unordered_set>
#include "ngraph/util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -135,8 +135,7 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list) ...@@ -135,8 +135,7 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return seed; return seed;
} }
void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, std::function<void(Node*)> f)
std::function<void(Node*)> f)
{ {
std::unordered_set<Node*> instances_seen; std::unordered_set<Node*> instances_seen;
deque<Node*> stack; deque<Node*> stack;
...@@ -151,7 +150,10 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, ...@@ -151,7 +150,10 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p,
f(n); f(n);
} }
stack.pop_front(); stack.pop_front();
for (auto arg : n->get_arguments()) { stack.push_front(arg.get()); } for (auto arg : n->get_arguments())
{
stack.push_front(arg.get());
}
} }
} }
...@@ -159,10 +161,7 @@ void ngraph::free_nodes(shared_ptr<Node> p) ...@@ -159,10 +161,7 @@ void ngraph::free_nodes(shared_ptr<Node> p)
{ {
std::deque<Node*> sorted_list; std::deque<Node*> sorted_list;
traverse_nodes(p, [&](Node* n) traverse_nodes(p, [&](Node* n) { sorted_list.push_front(n); });
{
sorted_list.push_front(n);
});
for (Node* n : sorted_list) for (Node* n : sorted_list)
{ {
......
...@@ -18,10 +18,10 @@ ...@@ -18,10 +18,10 @@
#include <chrono> #include <chrono>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
namespace ngraph namespace ngraph
{ {
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <list> #include <list>
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/visualize.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
......
...@@ -33,8 +33,10 @@ TEST(build_graph, build_simple) ...@@ -33,8 +33,10 @@ TEST(build_graph, build_simple)
ASSERT_EQ(dot->get_arguments()[0], arg2); ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0); ASSERT_EQ(dot->get_arguments()[1], arg0);
auto result_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{10,32,7}); auto result_type =
auto cluster_0 = make_shared<Function>(dot, result_type, op::Parameters{arg0, arg1, arg2, arg3}); make_shared<TensorViewType>(element::Float32::element_type(), Shape{10, 32, 7});
auto cluster_0 =
make_shared<Function>(dot, result_type, op::Parameters{arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->get_result(), dot); ASSERT_EQ(cluster_0->get_result(), dot);
} }
...@@ -182,4 +184,6 @@ TEST(build_graph, set_value_type_checked) ...@@ -182,4 +184,6 @@ TEST(build_graph, set_value_type_checked)
} }
// Check argument inverses // Check argument inverses
TEST(build_graph, arg_inverse) {} TEST(build_graph, arg_inverse)
{
}
This diff is collapsed.
...@@ -13,12 +13,12 @@ ...@@ -13,12 +13,12 @@
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <iostream> #include <iostream>
#include <vector>
#include <mkldnn.hpp> #include <mkldnn.hpp>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
static int tensor_volume(const mkldnn::memory::dims &t) static int tensor_volume(const mkldnn::memory::dims& t)
{ {
int x = 1; int x = 1;
for (const auto i : t) for (const auto i : t)
...@@ -26,7 +26,6 @@ static int tensor_volume(const mkldnn::memory::dims &t) ...@@ -26,7 +26,6 @@ static int tensor_volume(const mkldnn::memory::dims &t)
return x; return x;
} }
TEST(mkldnn, engine) TEST(mkldnn, engine)
{ {
using namespace mkldnn; using namespace mkldnn;
...@@ -39,13 +38,15 @@ TEST(mkldnn, engine) ...@@ -39,13 +38,15 @@ TEST(mkldnn, engine)
const int mb = 2; const int mb = 2;
const int groups = 2; const int groups = 2;
memory::dims input_tz = {mb, 256, 13, 13}; memory::dims input_tz = {mb, 256, 13, 13};
memory::dims weights_tz = {groups, 384/groups, 256/groups, 3, 3}; memory::dims weights_tz = {groups, 384 / groups, 256 / groups, 3, 3};
memory::dims bias_tz = {384}; memory::dims bias_tz = {384};
memory::dims strides = {1, 1}; memory::dims strides = {1, 1};
memory::dims padding = {0, 0}; memory::dims padding = {0, 0};
memory::dims output_tz = {mb, 384, memory::dims output_tz = {
(input_tz[2] + 2*padding[0] - weights_tz[3])/strides[0] + 1, mb,
(input_tz[3] + 2*padding[1] - weights_tz[4])/strides[1] + 1, 384,
(input_tz[2] + 2 * padding[0] - weights_tz[3]) / strides[0] + 1,
(input_tz[3] + 2 * padding[1] - weights_tz[4]) / strides[1] + 1,
}; };
std::vector<float> input(tensor_volume(input_tz), .0f); std::vector<float> input(tensor_volume(input_tz), .0f);
...@@ -54,7 +55,8 @@ TEST(mkldnn, engine) ...@@ -54,7 +55,8 @@ TEST(mkldnn, engine)
std::vector<float> output(tensor_volume(output_tz), .0f); std::vector<float> output(tensor_volume(output_tz), .0f);
auto c3_src_desc = memory::desc({input_tz}, memory::data_type::f32, memory::format::nchw); auto c3_src_desc = memory::desc({input_tz}, memory::data_type::f32, memory::format::nchw);
auto c3_weights_desc = memory::desc({weights_tz}, memory::data_type::f32, memory::format::goihw); auto c3_weights_desc =
memory::desc({weights_tz}, memory::data_type::f32, memory::format::goihw);
auto c3_bias_desc = memory::desc({bias_tz}, memory::data_type::f32, memory::format::x); auto c3_bias_desc = memory::desc({bias_tz}, memory::data_type::f32, memory::format::x);
auto c3_dst_desc = memory::desc({output_tz}, memory::data_type::f32, memory::format::nchw); auto c3_dst_desc = memory::desc({output_tz}, memory::data_type::f32, memory::format::nchw);
...@@ -63,11 +65,22 @@ TEST(mkldnn, engine) ...@@ -63,11 +65,22 @@ TEST(mkldnn, engine)
auto c3_bias = memory({c3_bias_desc, cpu_engine}, bias.data()); auto c3_bias = memory({c3_bias_desc, cpu_engine}, bias.data());
auto c3_dst = memory({c3_dst_desc, cpu_engine}, output.data()); auto c3_dst = memory({c3_dst_desc, cpu_engine}, output.data());
auto c3 = convolution_forward(convolution_forward::primitive_desc(convolution_forward::desc(prop_kind::forward, auto c3 = convolution_forward(convolution_forward::primitive_desc(
convolution_forward::desc(prop_kind::forward,
algorithm::convolution_direct, algorithm::convolution_direct,
c3_src_desc, c3_weights_desc, c3_bias_desc, c3_dst_desc, c3_src_desc,
strides, padding, padding, padding_kind::zero), c3_weights_desc,
cpu_engine), c3_src, c3_weights, c3_bias, c3_dst); c3_bias_desc,
c3_dst_desc,
strides,
padding,
padding,
padding_kind::zero),
cpu_engine),
c3_src,
c3_weights,
c3_bias,
c3_dst);
stream(stream::kind::eager).submit({c3}).wait(); stream(stream::kind::eager).submit({c3}).wait();
})); }));
......
...@@ -19,16 +19,16 @@ ...@@ -19,16 +19,16 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/pass/liveness.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/log.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
...@@ -81,8 +81,6 @@ TEST(pass, liveness) ...@@ -81,8 +81,6 @@ TEST(pass, liveness)
// auto exc = ex.executor(seq_stuff); // auto exc = ex.executor(seq_stuff);
// return exc; // return exc;
// lg = LivenessGraph(exc.exop.ops) // lg = LivenessGraph(exc.exop.ops)
// lg.layout_memory() // lg.layout_memory()
......
...@@ -20,15 +20,15 @@ ...@@ -20,15 +20,15 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
using namespace ngraph; using namespace ngraph;
......
...@@ -12,20 +12,20 @@ ...@@ -12,20 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <memory>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/function.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
using namespace std; using namespace std;
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#include <algorithm> #include <algorithm>
#include "test_tools.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "test_tools.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -73,7 +73,8 @@ shared_ptr<Function> make_test_graph() ...@@ -73,7 +73,8 @@ shared_ptr<Function> make_test_graph()
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{}); auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto f0 = make_shared<Function>(r0, rt, op::Parameters{arg_0, arg_1, arg_2, arg_3, arg_4, arg_5}); auto f0 =
make_shared<Function>(r0, rt, op::Parameters{arg_0, arg_1, arg_2, arg_3, arg_4, arg_5});
return f0; return f0;
} }
...@@ -81,9 +82,6 @@ shared_ptr<Function> make_test_graph() ...@@ -81,9 +82,6 @@ shared_ptr<Function> make_test_graph()
size_t get_node_count(std::shared_ptr<Node> n) size_t get_node_count(std::shared_ptr<Node> n)
{ {
size_t node_count = 0; size_t node_count = 0;
traverse_nodes(n, [&](const Node* node) { traverse_nodes(n, [&](const Node* node) { node_count++; });
node_count++;
});
return node_count; return node_count;
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment