Commit c7b51d2d authored by Robert Kimball's avatar Robert Kimball

apply new .clang-format

parent 158de495
...@@ -27,7 +27,6 @@ namespace ngraph ...@@ -27,7 +27,6 @@ namespace ngraph
{ {
public: public:
size_t size() const { return m_size; } size_t size() const { return m_size; }
protected: protected:
size_t m_size; size_t m_size;
}; };
......
...@@ -29,7 +29,6 @@ namespace ngraph ...@@ -29,7 +29,6 @@ namespace ngraph
{ {
public: public:
BufferPos() {} BufferPos() {}
BufferPos(std::shared_ptr<Buffer> buffer, size_t offset, size_t size) BufferPos(std::shared_ptr<Buffer> buffer, size_t offset, size_t size)
: m_buffer(buffer) : m_buffer(buffer)
, m_offset(offset) , m_offset(offset)
...@@ -43,8 +42,8 @@ namespace ngraph ...@@ -43,8 +42,8 @@ namespace ngraph
protected: protected:
std::shared_ptr<Buffer> m_buffer; std::shared_ptr<Buffer> m_buffer;
size_t m_offset; size_t m_offset;
size_t m_size; size_t m_size;
}; };
} }
} }
...@@ -24,10 +24,10 @@ using ngraph::TensorViewType; ...@@ -24,10 +24,10 @@ using ngraph::TensorViewType;
DenseTensorViewLayout::DenseTensorViewLayout(const TensorView& tensor_view) DenseTensorViewLayout::DenseTensorViewLayout(const TensorView& tensor_view)
: TensorViewLayout(tensor_view) : TensorViewLayout(tensor_view)
{ {
auto tensor_view_type = tensor_view.get_tensor_view_type(); auto tensor_view_type = tensor_view.get_tensor_view_type();
Shape shape = tensor_view_type->get_shape(); Shape shape = tensor_view_type->get_shape();
m_size = ngraph::shape_size(shape); m_size = ngraph::shape_size(shape);
m_strides = ngraph::row_major_strides(shape); m_strides = ngraph::row_major_strides(shape);
} }
size_t DenseTensorViewLayout::get_index_offset(const std::vector<size_t>& indices) size_t DenseTensorViewLayout::get_index_offset(const std::vector<size_t>& indices)
......
...@@ -36,15 +36,14 @@ namespace ngraph ...@@ -36,15 +36,14 @@ namespace ngraph
DenseTensorViewLayout(const TensorView& tensor_view); DenseTensorViewLayout(const TensorView& tensor_view);
virtual size_t get_size() override { return m_size; } virtual size_t get_size() override { return m_size; }
size_t get_offset() const { return m_offset; } size_t get_offset() const { return m_offset; }
virtual size_t get_index_offset(const std::vector<size_t>& indices) override; virtual size_t get_index_offset(const std::vector<size_t>& indices) override;
const Strides& get_strides() const { return m_strides; } const Strides& get_strides() const { return m_strides; }
protected: protected:
Strides m_strides; Strides m_strides;
size_t m_offset; size_t m_offset;
size_t m_size; size_t m_size;
}; };
} }
} }
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/buffer_pos.hpp" #include "ngraph/descriptor/buffer_pos.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -41,7 +41,6 @@ namespace ngraph ...@@ -41,7 +41,6 @@ namespace ngraph
public: public:
virtual ~TensorViewLayout() {} virtual ~TensorViewLayout() {}
/// Extent of this view in buffer. /// Extent of this view in buffer.
/// ///
/// When we support non-linear buffers, this will need to be something other than size_t. /// When we support non-linear buffers, this will need to be something other than size_t.
...@@ -52,15 +51,17 @@ namespace ngraph ...@@ -52,15 +51,17 @@ namespace ngraph
/// With non-linear buffers, this will need to be something other than size_t. /// With non-linear buffers, this will need to be something other than size_t.
virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0; virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0;
const Shape& get_shape() const { return m_tensor_view.get_tensor_view_type()->get_shape(); } const Shape& get_shape() const
{
return m_tensor_view.get_tensor_view_type()->get_shape();
}
/// Where this view is located in the buffer. /// Where this view is located in the buffer.
const BufferPos& get_buffer_pos() const { return m_buffer_pos; } const BufferPos& get_buffer_pos() const { return m_buffer_pos; }
BufferPos& get_buffer_pos() { return m_buffer_pos; } BufferPos& get_buffer_pos() { return m_buffer_pos; }
protected: protected:
const ngraph::descriptor::TensorView& m_tensor_view; const ngraph::descriptor::TensorView& m_tensor_view;
BufferPos m_buffer_pos; BufferPos m_buffer_pos;
}; };
} }
} }
......
...@@ -18,9 +18,9 @@ using namespace ngraph; ...@@ -18,9 +18,9 @@ using namespace ngraph;
using namespace descriptor; using namespace descriptor;
PrimaryTensorView::PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type, PrimaryTensorView::PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type,
const std::string& name, const std::string& name,
bool is_output, bool is_output,
bool is_input) bool is_input)
: TensorView(tensor_view_type) : TensorView(tensor_view_type)
, m_tensor(tensor_view_type->get_element_type(), this, name, is_output, is_input) , m_tensor(tensor_view_type->get_element_type(), this, name, is_output, is_input)
{ {
......
...@@ -41,12 +41,12 @@ namespace ngraph ...@@ -41,12 +41,12 @@ namespace ngraph
/// @param is_output The view can be read from the host at the end of a computation. /// @param is_output The view can be read from the host at the end of a computation.
/// @param is_input The view can be written from the host at the beginning of a computation. /// @param is_input The view can be written from the host at the beginning of a computation.
PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type, PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type,
const std::string& name, const std::string& name,
bool is_output, bool is_output,
bool is_input); bool is_input);
virtual const Tensor& get_tensor() const override; virtual const Tensor& get_tensor() const override;
virtual Tensor& get_tensor() override; virtual Tensor& get_tensor() override;
protected: protected:
Tensor m_tensor; Tensor m_tensor;
......
...@@ -20,10 +20,10 @@ using namespace ngraph; ...@@ -20,10 +20,10 @@ using namespace ngraph;
using namespace ngraph::descriptor; using namespace ngraph::descriptor;
Tensor::Tensor(const element::Type& element_type, Tensor::Tensor(const element::Type& element_type,
PrimaryTensorView* primary_tensor_view, PrimaryTensorView* primary_tensor_view,
const std::string& name, const std::string& name,
bool is_output, bool is_output,
bool is_input) bool is_input)
: m_element_type(element_type) : m_element_type(element_type)
, m_primary_tensor_view(primary_tensor_view) , m_primary_tensor_view(primary_tensor_view)
, m_is_output{is_output} , m_is_output{is_output}
......
...@@ -44,34 +44,34 @@ private: ...@@ -44,34 +44,34 @@ private:
Tensor& operator=(const Tensor&) = delete; Tensor& operator=(const Tensor&) = delete;
Tensor(const element::Type& element_type, Tensor(const element::Type& element_type,
PrimaryTensorView* tensor_view, PrimaryTensorView* tensor_view,
const std::string& name, const std::string& name,
bool is_output, bool is_output,
bool is_input); bool is_input);
std::string get_next_view_name(); std::string get_next_view_name();
public: public:
bool is_output() const { return m_is_output; } bool is_output() const { return m_is_output; }
bool is_input() const { return m_is_input; } bool is_input() const { return m_is_input; }
bool is_persistent() const { return m_is_persistent; } bool is_persistent() const { return m_is_persistent; }
const std::string& get_name() const { return m_name; } const std::string& get_name() const { return m_name; }
size_t size() const; size_t size() const;
void set_pool_offset(size_t); void set_pool_offset(size_t);
size_t get_pool_offset() const; size_t get_pool_offset() const;
static std::string make_tensor_name(const Node* node, size_t value_index); static std::string make_tensor_name(const Node* node, size_t value_index);
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
PrimaryTensorView* m_primary_tensor_view; PrimaryTensorView* m_primary_tensor_view;
bool m_is_output; bool m_is_output;
bool m_is_input; bool m_is_input;
bool m_is_persistent; bool m_is_persistent;
std::string m_name; std::string m_name;
size_t m_next_view_id; size_t m_next_view_id;
size_t m_size; size_t m_size;
size_t m_pool_offset; size_t m_pool_offset;
}; };
std::ostream& operator<<(std::ostream&, const ngraph::descriptor::Tensor&); std::ostream& operator<<(std::ostream&, const ngraph::descriptor::Tensor&);
...@@ -49,7 +49,7 @@ namespace ngraph ...@@ -49,7 +49,7 @@ namespace ngraph
public: public:
virtual ~TensorView() {} virtual ~TensorView() {}
virtual const Tensor& get_tensor() const = 0; virtual const Tensor& get_tensor() const = 0;
virtual Tensor& get_tensor() = 0; virtual Tensor& get_tensor() = 0;
virtual std::shared_ptr<const ValueType> get_value_type() const override virtual std::shared_ptr<const ValueType> get_value_type() const override
{ {
...@@ -57,7 +57,6 @@ namespace ngraph ...@@ -57,7 +57,6 @@ namespace ngraph
} }
const std::string& get_name() const { return m_name; } const std::string& get_name() const { return m_name; }
std::shared_ptr<const TensorViewType> get_tensor_view_type() const std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{ {
return m_tensor_view_type; return m_tensor_view_type;
...@@ -81,9 +80,9 @@ namespace ngraph ...@@ -81,9 +80,9 @@ namespace ngraph
} }
protected: protected:
std::shared_ptr<const TensorViewType> m_tensor_view_type; std::shared_ptr<const TensorViewType> m_tensor_view_type;
std::shared_ptr<layout::TensorViewLayout> m_tensor_view_layout; std::shared_ptr<layout::TensorViewLayout> m_tensor_view_layout;
std::string m_name; std::string m_name;
}; };
using TensorViewPtrs = std::vector<std::shared_ptr<TensorView>>; using TensorViewPtrs = std::vector<std::shared_ptr<TensorView>>;
......
...@@ -33,7 +33,7 @@ Tuple::Tuple(const std::vector<std::shared_ptr<ngraph::descriptor::Value>>& elem ...@@ -33,7 +33,7 @@ Tuple::Tuple(const std::vector<std::shared_ptr<ngraph::descriptor::Value>>& elem
} }
void Tuple::collect_tensor_views(std::vector<std::shared_ptr<TensorView>>& views, void Tuple::collect_tensor_views(std::vector<std::shared_ptr<TensorView>>& views,
const std::shared_ptr<Value>& value) const const std::shared_ptr<Value>& value) const
{ {
for (auto element : m_elements) for (auto element : m_elements)
{ {
......
...@@ -31,7 +31,7 @@ namespace ngraph ...@@ -31,7 +31,7 @@ namespace ngraph
Tuple(const std::vector<std::shared_ptr<ngraph::descriptor::Value>>& elements); Tuple(const std::vector<std::shared_ptr<ngraph::descriptor::Value>>& elements);
const std::shared_ptr<ngraph::TupleType> get_tuple_type() const; const std::shared_ptr<ngraph::TupleType> get_tuple_type() const;
std::shared_ptr<ngraph::TupleType> get_tuple_type(); std::shared_ptr<ngraph::TupleType> get_tuple_type();
virtual std::shared_ptr<const ValueType> get_value_type() const override virtual std::shared_ptr<const ValueType> get_value_type() const override
{ {
...@@ -42,7 +42,7 @@ namespace ngraph ...@@ -42,7 +42,7 @@ namespace ngraph
const std::shared_ptr<Value>& value) const override; const std::shared_ptr<Value>& value) const override;
protected: protected:
std::shared_ptr<ngraph::TupleType> m_tuple_type; std::shared_ptr<ngraph::TupleType> m_tuple_type;
std::vector<std::shared_ptr<ngraph::descriptor::Value>> m_elements; std::vector<std::shared_ptr<ngraph::descriptor::Value>> m_elements;
}; };
} }
......
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
{ {
public: public:
virtual ~Value() {} virtual ~Value() {}
virtual std::shared_ptr<const ngraph::ValueType> get_value_type() const = 0; virtual std::shared_ptr<const ngraph::ValueType> get_value_type() const = 0;
/// @brief helper for collecting all the tensor views in a sequence of values /// @brief helper for collecting all the tensor views in a sequence of values
/// ///
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
Function::Function(const std::shared_ptr<Node>& result, Function::Function(const std::shared_ptr<Node>& result,
const std::shared_ptr<ValueType>& result_type, const std::shared_ptr<ValueType>& result_type,
const std::vector<std::shared_ptr<op::Parameter>>& parameters) const std::vector<std::shared_ptr<op::Parameter>>& parameters)
: m_result(result) : m_result(result)
, m_parameters(parameters) , m_parameters(parameters)
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <initializer_list> #include <initializer_list>
#include <memory> #include <memory>
#include <vector>
#include <string> #include <string>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp" #include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
...@@ -32,24 +32,21 @@ namespace ngraph ...@@ -32,24 +32,21 @@ namespace ngraph
class Function class Function
{ {
public: public:
Function(const std::shared_ptr<Node>& result, Function(const std::shared_ptr<Node>& result,
const std::shared_ptr<ValueType>& result_type, const std::shared_ptr<ValueType>& result_type,
const std::vector<std::shared_ptr<op::Parameter>>& parameters); const std::vector<std::shared_ptr<op::Parameter>>& parameters);
std::shared_ptr<Node> get_result() { return m_result; } std::shared_ptr<Node> get_result() { return m_result; }
const std::vector<std::shared_ptr<op::Parameter>> get_parameters() const const std::vector<std::shared_ptr<op::Parameter>> get_parameters() const
{ {
return m_parameters; return m_parameters;
} }
const std::shared_ptr<ValueType> get_result_type() const const std::shared_ptr<ValueType> get_result_type() const { return m_result_type; }
{
return m_result_type;
}
std::string get_name() const { return m_name; } std::string get_name() const { return m_name; }
protected: protected:
std::shared_ptr<Node> m_result; std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
std::string m_name; std::string m_name;
std::shared_ptr<ValueType> m_result_type; std::shared_ptr<ValueType> m_result_type;
}; };
} }
...@@ -30,12 +30,12 @@ namespace nervana ...@@ -30,12 +30,12 @@ namespace nervana
class thread_starter; class thread_starter;
} }
string nervana::logger::log_path; string nervana::logger::log_path;
deque<string> nervana::logger::queue; deque<string> nervana::logger::queue;
static mutex queue_mutex; static mutex queue_mutex;
static condition_variable queue_condition; static condition_variable queue_condition;
static unique_ptr<thread> queue_thread; static unique_ptr<thread> queue_thread;
static bool active = false; static bool active = false;
class nervana::thread_starter class nervana::thread_starter
{ {
...@@ -53,7 +53,7 @@ void nervana::logger::set_log_path(const string& path) ...@@ -53,7 +53,7 @@ void nervana::logger::set_log_path(const string& path)
void nervana::logger::start() void nervana::logger::start()
{ {
active = true; active = true;
queue_thread = unique_ptr<thread>(new thread(&thread_entry, nullptr)); queue_thread = unique_ptr<thread>(new thread(&thread_entry, nullptr));
} }
...@@ -103,8 +103,8 @@ nervana::log_helper::log_helper(LOG_TYPE type, const char* file, int line, const ...@@ -103,8 +103,8 @@ nervana::log_helper::log_helper(LOG_TYPE type, const char* file, int line, const
} }
std::time_t tt = chrono::system_clock::to_time_t(chrono::system_clock::now()); std::time_t tt = chrono::system_clock::to_time_t(chrono::system_clock::now());
auto tm = std::gmtime(&tt); auto tm = std::gmtime(&tt);
char buffer[256]; char buffer[256];
// strftime(buffer,sizeof(buffer), "%d/%b/%Y:%H:%M:%S %z", tm); // strftime(buffer,sizeof(buffer), "%d/%b/%Y:%H:%M:%S %z", tm);
// strftime(buffer,sizeof(buffer), "%Y-%m-%d %H:%M:%S UTC", tm); // strftime(buffer,sizeof(buffer), "%Y-%m-%d %H:%M:%S UTC", tm);
strftime(buffer, sizeof(buffer), "%Y-%m-%dT%H:%M:%Sz", tm); strftime(buffer, sizeof(buffer), "%Y-%m-%dT%H:%M:%Sz", tm);
......
...@@ -36,10 +36,10 @@ namespace nervana ...@@ -36,10 +36,10 @@ namespace nervana
return i < _size ? _string[i] : throw std::out_of_range(""); return i < _size ? _string[i] : throw std::out_of_range("");
} }
constexpr const char* get_ptr(size_t offset) const { return &_string[offset]; } constexpr const char* get_ptr(size_t offset) const { return &_string[offset]; }
constexpr size_t size() const { return _size; } constexpr size_t size() const { return _size; }
private: private:
const char* _string; const char* _string;
size_t _size; size_t _size;
}; };
constexpr const char* find_last(conststring s, size_t offset, char ch) constexpr const char* find_last(conststring s, size_t offset, char ch)
...@@ -84,23 +84,23 @@ namespace nervana ...@@ -84,23 +84,23 @@ namespace nervana
static void log_item(const std::string& s); static void log_item(const std::string& s);
static void process_event(const std::string& s); static void process_event(const std::string& s);
static void thread_entry(void* param); static void thread_entry(void* param);
static std::string log_path; static std::string log_path;
static std::deque<std::string> queue; static std::deque<std::string> queue;
}; };
#define NGRAPH_ERR \ #define NGRAPH_ERR \
nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_ERROR, \ nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_ERROR, \
nervana::get_file_name(__FILE__), \ nervana::get_file_name(__FILE__), \
__LINE__, \ __LINE__, \
__PRETTY_FUNCTION__) \ __PRETTY_FUNCTION__) \
.stream() .stream()
#define NGRAPH_WARN \ #define NGRAPH_WARN \
nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_WARNING, \ nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_WARNING, \
nervana::get_file_name(__FILE__), \ nervana::get_file_name(__FILE__), \
__LINE__, \ __LINE__, \
__PRETTY_FUNCTION__) \ __PRETTY_FUNCTION__) \
.stream() .stream()
#define NGRAPH_INFO \ #define NGRAPH_INFO \
nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_INFO, \ nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_INFO, \
nervana::get_file_name(__FILE__), \ nervana::get_file_name(__FILE__), \
__LINE__, \ __LINE__, \
......
...@@ -32,7 +32,9 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType> ...@@ -32,7 +32,9 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType>
} }
} }
Node::~Node() {} Node::~Node()
{
}
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type) void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
{ {
......
...@@ -64,10 +64,8 @@ namespace ngraph ...@@ -64,10 +64,8 @@ namespace ngraph
void assign_tensors(); void assign_tensors();
const Nodes& get_arguments() const { return m_arguments; } const Nodes& get_arguments() const { return m_arguments; }
void clear_arguments() { m_arguments.clear(); } void clear_arguments() { m_arguments.clear(); }
const std::multiset<Node*>& users() const { return m_users; } const std::multiset<Node*>& users() const { return m_users; }
virtual std::string get_node_id() const; virtual std::string get_node_id() const;
/// Return true if this has the same implementing class as node. This /// Return true if this has the same implementing class as node. This
...@@ -78,9 +76,8 @@ namespace ngraph ...@@ -78,9 +76,8 @@ namespace ngraph
return typeid(*this) == typeid(*node.get()); return typeid(*this) == typeid(*node.get());
} }
std::shared_ptr<const ValueType> get_value_type() { return m_value_type; } std::shared_ptr<const ValueType> get_value_type() { return m_value_type; }
const std::shared_ptr<const ValueType> get_value_type() const { return m_value_type; } const std::shared_ptr<const ValueType> get_value_type() const { return m_value_type; }
void set_value_type(const element::Type& element_type, const Shape& shape) void set_value_type(const element::Type& element_type, const Shape& shape)
{ {
m_value_type = std::make_shared<TensorViewType>(element_type, shape); m_value_type = std::make_shared<TensorViewType>(element_type, shape);
...@@ -101,27 +98,26 @@ namespace ngraph ...@@ -101,27 +98,26 @@ namespace ngraph
bool is_output() const; bool is_output() const;
void set_is_output(); void set_is_output();
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
std::vector<descriptor::Input>& get_inputs() { return m_inputs; } std::vector<descriptor::Input>& get_inputs() { return m_inputs; }
const std::vector<descriptor::Input>& get_inputs() const { return m_inputs; } const std::vector<descriptor::Input>& get_inputs() const { return m_inputs; }
std::vector<descriptor::Output>& get_outputs() { return m_outputs; } std::vector<descriptor::Output>& get_outputs() { return m_outputs; }
const std::vector<descriptor::Output>& get_outputs() const { return m_outputs; } const std::vector<descriptor::Output>& get_outputs() const { return m_outputs; }
std::unordered_set<descriptor::Tensor*> liveness_live_list; std::unordered_set<descriptor::Tensor*> liveness_live_list;
std::unordered_set<descriptor::Tensor*> liveness_new_list; std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list; std::unordered_set<descriptor::Tensor*> liveness_free_list;
protected: protected:
Nodes m_arguments; Nodes m_arguments;
std::shared_ptr<const ValueType> m_value_type; std::shared_ptr<const ValueType> m_value_type;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
std::string m_name; std::string m_name;
size_t m_instance_id; size_t m_instance_id;
static size_t m_next_instance_id; static size_t m_next_instance_id;
std::vector<descriptor::Input> m_inputs; std::vector<descriptor::Input> m_inputs;
std::vector<descriptor::Output> m_outputs; std::vector<descriptor::Output> m_outputs;
bool m_is_output; bool m_is_output;
}; };
} }
...@@ -19,8 +19,7 @@ using namespace ngraph; ...@@ -19,8 +19,7 @@ using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
const element::Type& BinaryElementwiseArithmetic::propagate_element_types( const element::Type& BinaryElementwiseArithmetic::propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg0_element_type, const element::Type& arg1_element_type) const
const element::Type& arg1_element_type) const
{ {
if (arg0_element_type != arg1_element_type) if (arg0_element_type != arg1_element_type)
{ {
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -41,11 +41,9 @@ void BinaryElementwiseBuiltin::propagate_types() ...@@ -41,11 +41,9 @@ void BinaryElementwiseBuiltin::propagate_types()
throw ngraph_error("Arguments must have the same tensor view shape"); throw ngraph_error("Arguments must have the same tensor view shape");
} }
const element::Type& result_element_type = const element::Type& result_element_type = propagate_element_types(
propagate_element_types(arg0_tensor_type->get_element_type(), arg0_tensor_type->get_element_type(), arg1_tensor_type->get_element_type());
arg1_tensor_type->get_element_type());
set_value_type_checked(make_shared<TensorViewType>(result_element_type, set_value_type_checked(
arg0_tensor_type->get_shape())); make_shared<TensorViewType>(result_element_type, arg0_tensor_type->get_shape()));
} }
...@@ -19,8 +19,7 @@ using namespace ngraph; ...@@ -19,8 +19,7 @@ using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
const element::Type& BinaryElementwiseComparison::propagate_element_types( const element::Type& BinaryElementwiseComparison::propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg0_element_type, const element::Type& arg1_element_type) const
const element::Type& arg1_element_type) const
{ {
if (arg0_element_type != arg1_element_type) if (arg0_element_type != arg1_element_type)
{ {
......
...@@ -19,7 +19,8 @@ using namespace ngraph::op; ...@@ -19,7 +19,8 @@ using namespace ngraph::op;
void Broadcast::propagate_types() void Broadcast::propagate_types()
{ {
if (m_arguments.size() != 1){ if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments."); throw ngraph_error("Wrong number of arguments.");
} }
...@@ -42,5 +43,6 @@ void Broadcast::propagate_types() ...@@ -42,5 +43,6 @@ void Broadcast::propagate_types()
{ {
throw ngraph_error("Broadcast arg, shape, and axes are incompatible"); throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
} }
set_value_type_checked(make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape)); set_value_type_checked(
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape));
} }
...@@ -28,8 +28,8 @@ namespace ngraph ...@@ -28,8 +28,8 @@ namespace ngraph
/// the remaining axes in shape must be the same as the shape of arg. /// the remaining axes in shape must be the same as the shape of arg.
/// ///
Broadcast(const std::shared_ptr<Node>& arg, Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape, const Shape& shape,
const AxisSet& broadcast_axes) const AxisSet& broadcast_axes)
: IndexBuiltin(arg) : IndexBuiltin(arg)
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
...@@ -37,12 +37,11 @@ namespace ngraph ...@@ -37,12 +37,11 @@ namespace ngraph
} }
virtual std::string description() const override { return "Broadcast"; } virtual std::string description() const override { return "Broadcast"; }
virtual void propagate_types() override; virtual void propagate_types() override;
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; } const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
protected: protected:
Shape m_shape; Shape m_shape;
AxisSet m_broadcast_axes; AxisSet m_broadcast_axes;
}; };
} }
......
...@@ -47,7 +47,7 @@ void Concat::propagate_types() ...@@ -47,7 +47,7 @@ void Concat::propagate_types()
size_t concatenation_axis_length = arg0_shape.at(m_concatenation_axis); size_t concatenation_axis_length = arg0_shape.at(m_concatenation_axis);
auto& arg0_element_type = arg0_tensor_view_type->get_element_type(); auto& arg0_element_type = arg0_tensor_view_type->get_element_type();
for(auto i = 1; i < m_arguments.size(); i++) for (auto i = 1; i < m_arguments.size(); i++)
{ {
auto argi_type = m_arguments.at(i)->get_value_type(); auto argi_type = m_arguments.at(i)->get_value_type();
if (nullptr == argi_type) if (nullptr == argi_type)
...@@ -72,11 +72,12 @@ void Concat::propagate_types() ...@@ -72,11 +72,12 @@ void Concat::propagate_types()
throw ngraph_error("Argument element types do not match"); throw ngraph_error("Argument element types do not match");
} }
for(auto j = 0; j < argi_shape.size(); j++) for (auto j = 0; j < argi_shape.size(); j++)
{ {
if (j != m_concatenation_axis && arg0_shape.at(j) != argi_shape.at(j)) if (j != m_concatenation_axis && arg0_shape.at(j) != argi_shape.at(j))
{ {
throw ngraph_error("Arguments to concat do not have same dimension on a non-concatenation axis"); throw ngraph_error(
"Arguments to concat do not have same dimension on a non-concatenation axis");
} }
else if (j == m_concatenation_axis) else if (j == m_concatenation_axis)
{ {
......
...@@ -30,17 +30,16 @@ namespace ngraph ...@@ -30,17 +30,16 @@ namespace ngraph
/// ///
/// Example: n0 has shape {2,4,2}, and n1 has shape {2,5,2}. Then the output of /// Example: n0 has shape {2,4,2}, and n1 has shape {2,5,2}. Then the output of
/// Concat(Nodes{n0,n1},1) will have shape {2,9,2}. /// Concat(Nodes{n0,n1},1) will have shape {2,9,2}.
Concat(const Nodes& args,size_t concatenation_axis) Concat(const Nodes& args, size_t concatenation_axis)
: Builtin(args) : Builtin(args)
, m_concatenation_axis(concatenation_axis) , m_concatenation_axis(concatenation_axis)
{ {
} }
virtual std::string description() const override { return "Concatenate"; } virtual std::string description() const override { return "Concatenate"; }
virtual void propagate_types() override; virtual void propagate_types() override;
size_t get_concatenation_axis() const { return m_concatenation_axis; } size_t get_concatenation_axis() const { return m_concatenation_axis; }
protected: protected:
const size_t m_concatenation_axis; const size_t m_concatenation_axis;
}; };
......
...@@ -16,7 +16,10 @@ ...@@ -16,7 +16,10 @@
using namespace ngraph::op; using namespace ngraph::op;
void ScalarConstantBase::propagate_types() {} void ScalarConstantBase::propagate_types()
{
void TensorConstantBase::propagate_types() {} }
void TensorConstantBase::propagate_types()
{
}
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <sstream> #include <sstream>
#include "ngraph/types/element_type.hpp"
#include "ngraph/runtime/utils.hpp" #include "ngraph/runtime/utils.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -60,22 +60,18 @@ namespace ngraph ...@@ -60,22 +60,18 @@ namespace ngraph
return ss.str(); return ss.str();
} }
type get_value() const type get_value() const { return m_value; }
{
return m_value;
}
protected: protected:
typename T::type m_value; typename T::type m_value;
}; };
using Float32ScalarConstant = ScalarConstant<element::Float32>; using Float32ScalarConstant = ScalarConstant<element::Float32>;
using Int8ScalarConstant = ScalarConstant<element::Int8>; using Int8ScalarConstant = ScalarConstant<element::Int8>;
using Int32ScalarConstant = ScalarConstant<element::Int32>; using Int32ScalarConstant = ScalarConstant<element::Int32>;
using Int64ScalarConstant = ScalarConstant<element::Int64>; using Int64ScalarConstant = ScalarConstant<element::Int64>;
using UInt8ScalarConstant = ScalarConstant<element::UInt8>; using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstant = ScalarConstant<element::UInt32>; using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstant = ScalarConstant<element::UInt64>; using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
// Defines methods to all constant tensors // Defines methods to all constant tensors
class TensorConstantBase : public Node class TensorConstantBase : public Node
...@@ -113,18 +109,21 @@ namespace ngraph ...@@ -113,18 +109,21 @@ namespace ngraph
return ss.str(); return ss.str();
} }
typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> get_value() const { return m_value; } typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> get_value() const
{
return m_value;
}
protected: protected:
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value; std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value;
}; };
using Float32TensorConstant = TensorConstant<element::Float32>; using Float32TensorConstant = TensorConstant<element::Float32>;
using Int8TensorConstant = TensorConstant<element::Int8>; using Int8TensorConstant = TensorConstant<element::Int8>;
using Int32TensorConstant = TensorConstant<element::Int32>; using Int32TensorConstant = TensorConstant<element::Int32>;
using Int64TensorConstant = TensorConstant<element::Int64>; using Int64TensorConstant = TensorConstant<element::Int64>;
using UInt8TensorConstant = TensorConstant<element::UInt8>; using UInt8TensorConstant = TensorConstant<element::UInt8>;
using UInt32TensorConstant = TensorConstant<element::UInt32>; using UInt32TensorConstant = TensorConstant<element::UInt32>;
using UInt64TensorConstant = TensorConstant<element::UInt64>; using UInt64TensorConstant = TensorConstant<element::UInt64>;
} }
} }
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
} }
virtual std::string description() const override { return "Convert"; } virtual std::string description() const override { return "Convert"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
......
...@@ -34,11 +34,11 @@ void Dot::propagate_types() ...@@ -34,11 +34,11 @@ void Dot::propagate_types()
throw ngraph_error("Arguments to dot must have the same element type"); throw ngraph_error("Arguments to dot must have the same element type");
} }
vector<size_t> arg0_shape = arg0_tensor_type->get_shape(); vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape(); vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1; size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction; size_t arg1_reduction;
const bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0; const bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0;
if (arg1_shape.size() > 1) if (arg1_shape.size() > 1)
{ {
...@@ -56,22 +56,23 @@ void Dot::propagate_types() ...@@ -56,22 +56,23 @@ void Dot::propagate_types()
vector<size_t> result_shape; vector<size_t> result_shape;
result_shape.reserve(arg0_shape.size() + arg1_shape.size() - (is_scalar_mult ? 0 : 2)); result_shape.reserve(arg0_shape.size() + arg1_shape.size() - (is_scalar_mult ? 0 : 2));
for(auto i = 0; i < arg0_shape.size(); i++) for (auto i = 0; i < arg0_shape.size(); i++)
{ {
if(is_scalar_mult || i != arg0_reduction) if (is_scalar_mult || i != arg0_reduction)
{ {
result_shape.push_back(arg0_shape[i]); result_shape.push_back(arg0_shape[i]);
} }
} }
for(auto i = 0; i < arg1_shape.size(); i++) for (auto i = 0; i < arg1_shape.size(); i++)
{ {
if(is_scalar_mult || i != arg1_reduction) if (is_scalar_mult || i != arg1_reduction)
{ {
result_shape.push_back(arg1_shape[i]); result_shape.push_back(arg1_shape[i]);
} }
} }
auto result_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape); auto result_type =
make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
set_value_type_checked(result_type); set_value_type_checked(result_type);
} }
...@@ -46,7 +46,7 @@ namespace ngraph ...@@ -46,7 +46,7 @@ namespace ngraph
} }
virtual std::string description() const override { return "Dot"; } virtual std::string description() const override { return "Dot"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
} }
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
/// @param function The function to be called /// @param function The function to be called
/// @param args The function arguments /// @param args The function arguments
/// ///
FunctionCall(const std::shared_ptr<Function>& function, FunctionCall(const std::shared_ptr<Function>& function,
const std::vector<std::shared_ptr<Node>>& args) const std::vector<std::shared_ptr<Node>>& args)
: Builtin(args) : Builtin(args)
, m_function(function) , m_function(function)
...@@ -36,10 +36,9 @@ namespace ngraph ...@@ -36,10 +36,9 @@ namespace ngraph
} }
virtual std::string description() const override { return "FunctionCall"; } virtual std::string description() const override { return "FunctionCall"; }
virtual void propagate_types() override; virtual void propagate_types() override;
std::shared_ptr<Function> get_function() const { return m_function; } std::shared_ptr<Function> get_function() const { return m_function; }
protected: protected:
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
}; };
......
...@@ -33,7 +33,8 @@ void GetTupleElement::propagate_types() ...@@ -33,7 +33,8 @@ void GetTupleElement::propagate_types()
throw ngraph_error("Argument must be a tuple view"); throw ngraph_error("Argument must be a tuple view");
} }
if (m_n >= arg0_tuple_type->get_element_types().size()){ if (m_n >= arg0_tuple_type->get_element_types().size())
{
throw ngraph_error("Indexing tuple beyond its size"); throw ngraph_error("Indexing tuple beyond its size");
} }
......
...@@ -31,11 +31,9 @@ namespace ngraph ...@@ -31,11 +31,9 @@ namespace ngraph
{ {
} }
virtual void propagate_types() override; virtual void propagate_types() override;
virtual std::string description() const override { return "GetTupleElement"; } virtual std::string description() const override { return "GetTupleElement"; }
size_t get_n() const { return m_n; } size_t get_n() const { return m_n; }
protected: protected:
size_t m_n; size_t m_n;
}; };
......
...@@ -31,7 +31,6 @@ namespace ngraph ...@@ -31,7 +31,6 @@ namespace ngraph
{ {
public: public:
virtual std::string description() const override { return "Builtin"; } virtual std::string description() const override { return "Builtin"; }
protected: protected:
Builtin(const std::vector<std::shared_ptr<Node>>& args) Builtin(const std::vector<std::shared_ptr<Node>>& args)
: Node(args) : Node(args)
...@@ -73,8 +72,8 @@ namespace ngraph ...@@ -73,8 +72,8 @@ namespace ngraph
: Builtin(Nodes{arg}) : Builtin(Nodes{arg})
{ {
} }
virtual const element::Type& propagate_element_types( virtual const element::Type&
const element::Type& arg_element_type) const = 0; propagate_element_types(const element::Type& arg_element_type) const = 0;
public: public:
virtual void propagate_types() override; virtual void propagate_types() override;
...@@ -87,8 +86,8 @@ namespace ngraph ...@@ -87,8 +86,8 @@ namespace ngraph
: UnaryElementwiseBuiltin({arg}) : UnaryElementwiseBuiltin({arg})
{ {
} }
virtual const element::Type& propagate_element_types( virtual const element::Type&
const element::Type& arg_element_type) const final override; propagate_element_types(const element::Type& arg_element_type) const final override;
}; };
/// Op(X, Y)[I] = op(X[I], Y[I]) /// Op(X, Y)[I] = op(X[I], Y[I])
...@@ -100,9 +99,9 @@ namespace ngraph ...@@ -100,9 +99,9 @@ namespace ngraph
: Builtin(Nodes{arg0, arg1}) : Builtin(Nodes{arg0, arg1})
{ {
} }
virtual const element::Type& propagate_element_types( virtual const element::Type&
const element::Type& arg0_element_type, propagate_element_types(const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const = 0; const element::Type& arg1_element_type) const = 0;
public: public:
virtual void propagate_types() override; virtual void propagate_types() override;
...@@ -111,34 +110,39 @@ namespace ngraph ...@@ -111,34 +110,39 @@ namespace ngraph
class BinaryElementwiseComparison : public BinaryElementwiseBuiltin class BinaryElementwiseComparison : public BinaryElementwiseBuiltin
{ {
public: public:
BinaryElementwiseComparison( BinaryElementwiseComparison(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
virtual std::string description() const override { return "BinaryElementwiseComparison"; } virtual std::string description() const override
{
return "BinaryElementwiseComparison";
}
//virtual void propagate_types() override; //virtual void propagate_types() override;
virtual const element::Type& propagate_element_types( virtual const element::Type&
const element::Type& arg0_element_type, propagate_element_types(const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const override; const element::Type& arg1_element_type) const override;
}; };
class BinaryElementwiseArithmetic : public BinaryElementwiseBuiltin class BinaryElementwiseArithmetic : public BinaryElementwiseBuiltin
{ {
public: public:
BinaryElementwiseArithmetic( BinaryElementwiseArithmetic(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
virtual std::string description() const override { return "BinaryElementwiseArithmetic"; } virtual std::string description() const override
{
return "BinaryElementwiseArithmetic";
}
//virtual void propagate_types() override; //virtual void propagate_types() override;
virtual const element::Type& propagate_element_types( virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const element::Type& arg1_element_type) const final override;
const final override;
}; };
} }
} }
...@@ -38,7 +38,9 @@ void Parameter::assign_function(Function* function, size_t index) ...@@ -38,7 +38,9 @@ void Parameter::assign_function(Function* function, size_t index)
throw ngraph_error("Re-assigning function to a parameter."); throw ngraph_error("Re-assigning function to a parameter.");
} }
m_function = function; m_function = function;
m_index = index; m_index = index;
} }
void Parameter::propagate_types() {} void Parameter::propagate_types()
{
}
...@@ -37,15 +37,15 @@ namespace ngraph ...@@ -37,15 +37,15 @@ namespace ngraph
void assign_function(Function* function, size_t index); void assign_function(Function* function, size_t index);
public: public:
Parameter(const std::shared_ptr<ValueType>& value_type=nullptr); Parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
Parameter(const ngraph::element::Type& element_type, const Shape& shape); Parameter(const ngraph::element::Type& element_type, const Shape& shape);
std::string description() const override { return "Parameter"; } std::string description() const override { return "Parameter"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
Function* m_function; Function* m_function;
size_t m_index; size_t m_index;
}; };
} }
} }
...@@ -30,7 +30,8 @@ void Reduce::propagate_types() ...@@ -30,7 +30,8 @@ void Reduce::propagate_types()
{ {
throw ngraph_error("Argument to reduce is missing type."); throw ngraph_error("Argument to reduce is missing type.");
} }
auto arg_reductee_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_reductee_type); auto arg_reductee_tensor_view_type =
dynamic_pointer_cast<const TensorViewType>(arg_reductee_type);
if (nullptr == arg_reductee_tensor_view_type) if (nullptr == arg_reductee_tensor_view_type)
{ {
throw ngraph_error("Argument to reduce is not a tensor view"); throw ngraph_error("Argument to reduce is not a tensor view");
...@@ -51,7 +52,8 @@ void Reduce::propagate_types() ...@@ -51,7 +52,8 @@ void Reduce::propagate_types()
throw ngraph_error("Argument for initial value is not a scalar"); throw ngraph_error("Argument for initial value is not a scalar");
} }
if (arg_init_tensor_view_type->get_element_type() != arg_reductee_tensor_view_type->get_element_type()) if (arg_init_tensor_view_type->get_element_type() !=
arg_reductee_tensor_view_type->get_element_type())
{ {
throw ngraph_error("Element types for reductee and initial values do not match"); throw ngraph_error("Element types for reductee and initial values do not match");
} }
...@@ -99,5 +101,6 @@ void Reduce::propagate_types() ...@@ -99,5 +101,6 @@ void Reduce::propagate_types()
throw ngraph_error("Return type from reduction function does not match expected"); throw ngraph_error("Return type from reduction function does not match expected");
} }
set_value_type_checked(make_shared<TensorViewType>(arg_reductee_tensor_view_type->get_element_type(), result_shape)); set_value_type_checked(make_shared<TensorViewType>(
arg_reductee_tensor_view_type->get_element_type(), result_shape));
} }
...@@ -27,22 +27,24 @@ namespace ngraph ...@@ -27,22 +27,24 @@ namespace ngraph
/// @param reduction_function The reduction function to use. /// @param reduction_function The reduction function to use.
/// @param reduction_axes The axis positions (0-based) to be reduced. /// @param reduction_axes The axis positions (0-based) to be reduced.
/// ///
Reduce(const std::shared_ptr<Node>& arg_reductee, Reduce(const std::shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init, const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function, const std::shared_ptr<Function>& reduction_function,
const AxisSet& reduction_axes) const AxisSet& reduction_axes)
: Builtin({arg_reductee,arg_init}) : Builtin({arg_reductee, arg_init})
, m_reduction_function(reduction_function) , m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes) , m_reduction_axes(reduction_axes)
{ {
} }
virtual std::string description() const override { return "Reduce"; } virtual std::string description() const override { return "Reduce"; }
virtual void propagate_types() override; virtual void propagate_types() override;
std::shared_ptr<Function> get_reduction_function() const { return m_reduction_function; } std::shared_ptr<Function> get_reduction_function() const
{
return m_reduction_function;
}
const AxisSet& get_reduction_axes() const { return m_reduction_axes; } const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
protected: protected:
std::shared_ptr<Function> m_reduction_function; std::shared_ptr<Function> m_reduction_function;
AxisSet m_reduction_axes; AxisSet m_reduction_axes;
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -42,8 +42,8 @@ void Select::propagate_types() ...@@ -42,8 +42,8 @@ void Select::propagate_types()
{ {
throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type"); throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type");
} }
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape() if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape() ||
|| arg0_tensor_type->get_shape() != arg2_tensor_type->get_shape()) arg0_tensor_type->get_shape() != arg2_tensor_type->get_shape())
{ {
throw ngraph_error("Arguments must have the same tensor view shape"); throw ngraph_error("Arguments must have the same tensor view shape");
} }
...@@ -54,4 +54,3 @@ void Select::propagate_types() ...@@ -54,4 +54,3 @@ void Select::propagate_types()
set_value_type_checked(arg1_tensor_type); set_value_type_checked(arg1_tensor_type);
} }
...@@ -24,7 +24,7 @@ namespace ngraph ...@@ -24,7 +24,7 @@ namespace ngraph
Select(const std::shared_ptr<Node>& arg0, Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1, const std::shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2) const std::shared_ptr<Node>& arg2)
: Builtin(Nodes{arg0, arg1, arg2}) : Builtin(Nodes{arg0, arg1, arg2})
{ {
} }
virtual std::string description() const override { return "Select"; } virtual std::string description() const override { return "Select"; }
......
...@@ -27,7 +27,7 @@ namespace ngraph ...@@ -27,7 +27,7 @@ namespace ngraph
} }
virtual std::string description() const override { return "Tuple"; } virtual std::string description() const override { return "Tuple"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
} }
...@@ -20,8 +20,8 @@ using namespace std; ...@@ -20,8 +20,8 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
const element::Type& UnaryElementwiseArithmetic::propagate_element_types( const element::Type&
const element::Type& arg_element_type) const UnaryElementwiseArithmetic::propagate_element_types(const element::Type& arg_element_type) const
{ {
if (arg_element_type == element::Bool::element_type()) if (arg_element_type == element::Bool::element_type())
{ {
......
...@@ -37,6 +37,6 @@ void UnaryElementwiseBuiltin::propagate_types() ...@@ -37,6 +37,6 @@ void UnaryElementwiseBuiltin::propagate_types()
const element::Type& result_element_type = const element::Type& result_element_type =
propagate_element_types(arg_tensor_type->get_element_type()); propagate_element_types(arg_tensor_type->get_element_type());
set_value_type_checked(make_shared<TensorViewType>(result_element_type, set_value_type_checked(
arg_tensor_type->get_shape())); make_shared<TensorViewType>(result_element_type, arg_tensor_type->get_shape()));
} }
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <fstream> #include <fstream>
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -51,7 +51,6 @@ bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes) ...@@ -51,7 +51,6 @@ bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes)
out << join(outputs); out << join(outputs);
out << "\n"; out << "\n";
for (const Tensor* tensor : node->liveness_live_list) for (const Tensor* tensor : node->liveness_live_list)
{ {
out << " L " << tensor->get_name() << "\n"; out << " L " << tensor->get_name() << "\n";
......
...@@ -35,5 +35,5 @@ public: ...@@ -35,5 +35,5 @@ public:
virtual bool run_on_call_list(std::list<Node*>&) override; virtual bool run_on_call_list(std::list<Node*>&) override;
private: private:
const std::string m_output_file; const std::string m_output_file;
}; };
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
#include <sstream> #include <sstream>
#include <unordered_set> #include <unordered_set>
#include "ngraph/log.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/log.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -31,7 +31,7 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops) ...@@ -31,7 +31,7 @@ bool pass::Liveness::run_on_call_list(list<Node*>& ops)
{ {
unordered_set<Tensor*> currently_live; unordered_set<Tensor*> currently_live;
for(auto it=ops.rbegin(); it!=ops.rend(); it++) for (auto it = ops.rbegin(); it != ops.rend(); it++)
{ {
Node* node = *it; Node* node = *it;
node->liveness_live_list.clear(); node->liveness_live_list.clear();
...@@ -143,13 +143,10 @@ void pass::Liveness::check_dependencies( ...@@ -143,13 +143,10 @@ void pass::Liveness::check_dependencies(
bool pass::Liveness::is_temporary(const Tensor& tensor) bool pass::Liveness::is_temporary(const Tensor& tensor)
{ {
return return tensor.is_persistent() == false && tensor.is_input() == false &&
tensor.is_persistent() == false tensor.is_output() == false;
&& tensor.is_input() == false // && tensor.is_constant() == false
&& tensor.is_output() == false // && tensor.is_compile_only() == false;
;
// && tensor.is_constant() == false
// && tensor.is_compile_only() == false;
} }
void pass::Liveness::validate_liveness(const list<Node*>& ops) void pass::Liveness::validate_liveness(const list<Node*>& ops)
...@@ -170,4 +167,3 @@ void pass::Liveness::validate_liveness(const list<Node*>& ops) ...@@ -170,4 +167,3 @@ void pass::Liveness::validate_liveness(const list<Node*>& ops)
dead_tensors.insert(node->liveness_free_list.begin(), node->liveness_free_list.end()); dead_tensors.insert(node->liveness_free_list.begin(), node->liveness_free_list.end());
} }
} }
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#pragma once #pragma once
#include "ngraph/pass/call_pass.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/pass/call_pass.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/function.hpp" #include "ngraph/pass/manager.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include <vector>
#include <memory>
#include <list> #include <list>
#include <memory>
#include <vector>
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/call_pass.hpp"
#include "ngraph/pass/tree_pass.hpp" #include "ngraph/pass/tree_pass.hpp"
...@@ -37,17 +37,17 @@ class ngraph::pass::ManagerState ...@@ -37,17 +37,17 @@ class ngraph::pass::ManagerState
{ {
public: public:
Function* get_function(); Function* get_function();
void set_function(Function*); void set_function(Function*);
size_t get_temporary_pool_size(); size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t); void set_temporary_pool_size(size_t);
std::list<Node*>& get_call_graph(); std::list<Node*>& get_call_graph();
const std::list<Node*>& get_call_graph() const; const std::list<Node*>& get_call_graph() const;
private: private:
Function* m_function = nullptr; Function* m_function = nullptr;
size_t m_temporary_pool_size = 0; size_t m_temporary_pool_size = 0;
std::list<Node*> m_call_graph; std::list<Node*> m_call_graph;
}; };
...@@ -59,7 +59,7 @@ public: ...@@ -59,7 +59,7 @@ public:
void initialize_default_passes(); void initialize_default_passes();
template<typename T, class... Args> template <typename T, class... Args>
void register_pass(Args... args) void register_pass(Args... args)
{ {
static_assert(std::is_base_of<pass::Base, T>::value, "pass not derived from pass base"); static_assert(std::is_base_of<pass::Base, T>::value, "pass not derived from pass base");
...@@ -86,5 +86,5 @@ private: ...@@ -86,5 +86,5 @@ private:
std::vector<std::shared_ptr<TreeBase>> m_tree_passes; std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes; std::vector<std::shared_ptr<CallBase>> m_call_passes;
ManagerState m_state; ManagerState m_state;
}; };
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
#include <exception> #include <exception>
#include <sstream> #include <sstream>
#include "ngraph/log.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "ngraph/pass/memory_layout.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -69,7 +69,6 @@ pass::MemoryManager::node::node(size_t size, block_state state) ...@@ -69,7 +69,6 @@ pass::MemoryManager::node::node(size_t size, block_state state)
: m_size{size} : m_size{size}
, m_state{state} , m_state{state}
{ {
} }
pass::MemoryManager::MemoryManager(size_t alignment) pass::MemoryManager::MemoryManager(size_t alignment)
...@@ -84,14 +83,10 @@ pass::MemoryManager::MemoryManager(size_t alignment) ...@@ -84,14 +83,10 @@ pass::MemoryManager::MemoryManager(size_t alignment)
size_t pass::MemoryManager::allocate(size_t size) size_t pass::MemoryManager::allocate(size_t size)
{ {
size_t rc; size_t rc;
switch(m_scheme) switch (m_scheme)
{ {
case allocation_scheme::FIRST_FIT: case allocation_scheme::FIRST_FIT: rc = first_fit(size); break;
rc = first_fit(size); case allocation_scheme::BEST_FIT: rc = best_fit(size); break;
break;
case allocation_scheme::BEST_FIT:
rc = best_fit(size);
break;
} }
return rc; return rc;
} }
...@@ -103,7 +98,7 @@ size_t pass::MemoryManager::best_fit(size_t size) ...@@ -103,7 +98,7 @@ size_t pass::MemoryManager::best_fit(size_t size)
size_t min_delta = numeric_limits<size_t>::max(); size_t min_delta = numeric_limits<size_t>::max();
auto best_fit = m_node_list.end(); auto best_fit = m_node_list.end();
size_t best_offset = offset; size_t best_offset = offset;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it) for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
{ {
if (it->m_state == block_state::FREE && it->m_size >= size) if (it->m_state == block_state::FREE && it->m_size >= size)
{ {
...@@ -143,7 +138,7 @@ size_t pass::MemoryManager::first_fit(size_t size) ...@@ -143,7 +138,7 @@ size_t pass::MemoryManager::first_fit(size_t size)
size = align(size, m_alignment); size = align(size, m_alignment);
size_t offset = 0; size_t offset = 0;
bool found = false; bool found = false;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it) for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
{ {
if (it->m_state == block_state::FREE && it->m_size >= size) if (it->m_state == block_state::FREE && it->m_size >= size)
{ {
...@@ -176,7 +171,7 @@ void pass::MemoryManager::free(size_t offset) ...@@ -176,7 +171,7 @@ void pass::MemoryManager::free(size_t offset)
{ {
size_t search_offset = 0; size_t search_offset = 0;
bool found = false; bool found = false;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it) for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
{ {
if (offset == search_offset) if (offset == search_offset)
{ {
......
...@@ -62,12 +62,11 @@ public: ...@@ -62,12 +62,11 @@ public:
node(size_t size, block_state state); node(size_t size, block_state state);
bool is_free() const { return m_state == block_state::FREE; } bool is_free() const { return m_state == block_state::FREE; }
size_t m_size;
size_t m_size; block_state m_state;
block_state m_state;
}; };
MemoryManager(size_t alignment=1); MemoryManager(size_t alignment = 1);
// memory_manager& alignment(size_t a); // memory_manager& alignment(size_t a);
size_t allocate(size_t size); size_t allocate(size_t size);
...@@ -81,17 +80,14 @@ public: ...@@ -81,17 +80,14 @@ public:
std::list<node>::iterator end() { return m_node_list.end(); } std::list<node>::iterator end() { return m_node_list.end(); }
std::list<node>::const_iterator begin() const { return m_node_list.cbegin(); } std::list<node>::const_iterator begin() const { return m_node_list.cbegin(); }
std::list<node>::const_iterator end() const { return m_node_list.cend(); } std::list<node>::const_iterator end() const { return m_node_list.cend(); }
const std::list<node>& get_node_list() const { return m_node_list; } const std::list<node>& get_node_list() const { return m_node_list; }
size_t max_allocated() const { return m_max_allocated; } size_t max_allocated() const { return m_max_allocated; }
private: private:
size_t first_fit(size_t size); size_t first_fit(size_t size);
size_t best_fit(size_t size); size_t best_fit(size_t size);
std::list<node> m_node_list; std::list<node> m_node_list;
size_t m_alignment; size_t m_alignment;
allocation_scheme m_scheme; allocation_scheme m_scheme;
size_t m_max_allocated; size_t m_max_allocated;
}; };
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <algorithm>
#include <fstream> #include <fstream>
#include <unordered_set>
#include <unordered_map> #include <unordered_map>
#include <algorithm> #include <unordered_set>
#include "memory_visualize.hpp" #include "memory_visualize.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
...@@ -154,8 +154,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>& ...@@ -154,8 +154,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
} }
i++; i++;
} }
sort(tensor_set.begin(), tensor_set.end(), [](const Tensor* t1, const Tensor* t2) sort(tensor_set.begin(), tensor_set.end(), [](const Tensor* t1, const Tensor* t2) {
{
return t1->size() < t2->size(); return t1->size() < t2->size();
}); });
for (const Tensor* tensor : tensor_set) for (const Tensor* tensor : tensor_set)
...@@ -206,12 +205,16 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod ...@@ -206,12 +205,16 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
y += line_spacing; y += line_spacing;
size_t x1 = offset; size_t x1 = offset;
size_t x2 = ((usage / memory_footprint) * scale) + offset; size_t x2 = ((usage / memory_footprint) * scale) + offset;
file << "<text x=\"" << 0 << "\" y=\"" << y + text_offset << "\" fill=\"" << "black" << "\">" << node->get_node_id() << "</text>\n"; file << "<text x=\"" << 0 << "\" y=\"" << y + text_offset << "\" fill=\""
file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y << "\""; << "black"
<< "\">" << node->get_node_id() << "</text>\n";
file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y
<< "\"";
file << " style=\"stroke:forestgreen;stroke-width:" << stroke_width << "\" />\n"; file << " style=\"stroke:forestgreen;stroke-width:" << stroke_width << "\" />\n";
x1 = x2; x1 = x2;
x2 = ((footprint / memory_footprint) * scale) + offset; x2 = ((footprint / memory_footprint) * scale) + offset;
file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y << "\""; file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y
<< "\"";
file << " style=\"stroke:firebrick;stroke-width:" << stroke_width << "\" />\n"; file << " style=\"stroke:firebrick;stroke-width:" << stroke_width << "\" />\n";
} }
file << "</svg>\n"; file << "</svg>\n";
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include <iostream>
#include <limits> #include <limits>
#include <list> #include <list>
#include <iostream>
#include "ngraph/pass/call_pass.hpp" #include "ngraph/pass/call_pass.hpp"
...@@ -47,6 +47,6 @@ private: ...@@ -47,6 +47,6 @@ private:
static size_t memory_usage(const Node*); static size_t memory_usage(const Node*);
static size_t memory_footprint(const Node*); static size_t memory_footprint(const Node*);
static size_t memory_footprint(const std::list<Node*>&); static size_t memory_footprint(const std::list<Node*>&);
const std::string m_filename; const std::string m_filename;
}; };
...@@ -27,6 +27,7 @@ namespace ngraph ...@@ -27,6 +27,7 @@ namespace ngraph
class ngraph::pass::Base class ngraph::pass::Base
{ {
friend class Manager; friend class Manager;
public: public:
protected: protected:
ManagerState& get_state(); ManagerState& get_state();
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <fstream> #include <fstream>
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -24,8 +24,7 @@ using namespace std; ...@@ -24,8 +24,7 @@ using namespace std;
bool pass::VisualizeTree::run_on_tree(std::shared_ptr<Node> base_node) bool pass::VisualizeTree::run_on_tree(std::shared_ptr<Node> base_node)
{ {
// map<size_t, list<node_ptr>> dependent_nodes; // map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(base_node, [&](Node* node) traverse_nodes(base_node, [&](Node* node) {
{
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
m_ss << add_attributes(arg.get()); m_ss << add_attributes(arg.get());
...@@ -73,7 +72,7 @@ std::string pass::VisualizeTree::get_attributes(const Node* node) ...@@ -73,7 +72,7 @@ std::string pass::VisualizeTree::get_attributes(const Node* node)
void pass::VisualizeTree::render() const void pass::VisualizeTree::render() const
{ {
#ifdef GRAPHVIZ_FOUND #ifdef GRAPHVIZ_FOUND
auto tmp_file = m_name + ".tmp"; auto tmp_file = m_name + ".tmp";
ofstream out(tmp_file); ofstream out(tmp_file);
if (out) if (out)
{ {
...@@ -84,7 +83,7 @@ void pass::VisualizeTree::render() const ...@@ -84,7 +83,7 @@ void pass::VisualizeTree::render() const
stringstream ss; stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << m_name; ss << "dot -Tpng " << tmp_file << " -o " << m_name;
auto cmd = ss.str(); auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r"); auto stream = popen(cmd.c_str(), "r");
pclose(stream); pclose(stream);
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include <set>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <set>
#include "ngraph/pass/tree_pass.hpp" #include "ngraph/pass/tree_pass.hpp"
...@@ -40,7 +40,7 @@ private: ...@@ -40,7 +40,7 @@ private:
std::string get_attributes(const Node* node); std::string get_attributes(const Node* node);
void render() const; void render() const;
std::stringstream m_ss; std::stringstream m_ss;
std::string m_name; std::string m_name;
std::set<const Node*> m_nodes_with_attributes; std::set<const Node*> m_nodes_with_attributes;
}; };
...@@ -20,10 +20,10 @@ using namespace std; ...@@ -20,10 +20,10 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace ngraph::runtime; using namespace ngraph::runtime;
CallFrame::CallFrame(size_t n_inputs, CallFrame::CallFrame(size_t n_inputs,
size_t n_outputs, size_t n_outputs,
const TensorViewPtrs& temps, const TensorViewPtrs& temps,
size_t initial_pc, size_t initial_pc,
const shared_ptr<vector<shared_ptr<Instruction>>>& instructions) const shared_ptr<vector<shared_ptr<Instruction>>>& instructions)
: m_n_inputs(n_inputs) : m_n_inputs(n_inputs)
...@@ -42,10 +42,10 @@ void CallFrame::tensor_call( ...@@ -42,10 +42,10 @@ void CallFrame::tensor_call(
copy(inputs.begin(), inputs.end(), m_tensor_views.begin()); copy(inputs.begin(), inputs.end(), m_tensor_views.begin());
copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs); copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs);
m_next_pc = m_initial_pc; m_next_pc = m_initial_pc;
m_return = false; m_return = false;
while (!m_return) while (!m_return)
{ {
m_pc = m_next_pc; m_pc = m_next_pc;
m_next_pc = m_pc + 1; m_next_pc = m_pc + 1;
m_instructions->at(m_pc)->execute(*this); m_instructions->at(m_pc)->execute(*this);
} }
......
...@@ -32,10 +32,10 @@ namespace ngraph ...@@ -32,10 +32,10 @@ namespace ngraph
{ {
public: public:
CallFrame( CallFrame(
size_t n_inputs, size_t n_inputs,
size_t n_outputs, size_t n_outputs,
const TensorViewPtrs& temps, const TensorViewPtrs& temps,
size_t initial_pc, size_t initial_pc,
const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions); const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions);
/// @brief Invoke the function with values matching the signature of the function. /// @brief Invoke the function with values matching the signature of the function.
...@@ -48,32 +48,28 @@ namespace ngraph ...@@ -48,32 +48,28 @@ namespace ngraph
void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs); void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs);
void set_return() { m_return = true; } void set_return() { m_return = true; }
std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; } std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; }
template <typename ET> template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i) ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i)
{ {
return m_tensor_views[i]->get_parameterized_tensor_view<ET>(); return m_tensor_views[i]->get_parameterized_tensor_view<ET>();
} }
template<typename ET> template <typename ET>
typename ET::type* get_tensor_view_data(size_t i) typename ET::type* get_tensor_view_data(size_t i)
{ {
return &get_parameterized_tensor_view<ET>(i)->get_vector()[0]; return &get_parameterized_tensor_view<ET>(i)->get_vector()[0];
} }
protected: protected:
size_t m_n_inputs; size_t m_n_inputs;
size_t m_n_outputs; size_t m_n_outputs;
TensorViewPtrs m_tensor_views; TensorViewPtrs m_tensor_views;
size_t m_initial_pc; size_t m_initial_pc;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions; std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
size_t m_pc; size_t m_pc;
size_t m_next_pc; size_t m_next_pc;
bool m_return; bool m_return;
}; };
} }
} }
...@@ -38,7 +38,8 @@ namespace ngraph ...@@ -38,7 +38,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = Eigen::abs(EigenArray1d<ET>(call_frame, m_arg)); EigenArray1d<ET>(call_frame, m_out) =
Eigen::abs(EigenArray1d<ET>(call_frame, m_arg));
} }
protected: protected:
......
...@@ -29,8 +29,7 @@ namespace ngraph ...@@ -29,8 +29,7 @@ namespace ngraph
class BroadcastScalarInstruction : public Instruction class BroadcastScalarInstruction : public Instruction
{ {
public: public:
BroadcastScalarInstruction(const TensorViewInfo& arg, BroadcastScalarInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -42,7 +41,7 @@ namespace ngraph ...@@ -42,7 +41,7 @@ namespace ngraph
// pull it out as a vector. This works because of the way // pull it out as a vector. This works because of the way
// fmt::V computes sizes---it lumps together any higher // fmt::V computes sizes---it lumps together any higher
// dimensions---while fmt::M ignores them. // dimensions---while fmt::M ignores them.
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg)(0,0); EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg)(0, 0);
} }
protected: protected:
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/eigen/utils.hpp" #include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
...@@ -29,7 +29,9 @@ namespace ngraph ...@@ -29,7 +29,9 @@ namespace ngraph
class CallInstruction : public Instruction class CallInstruction : public Instruction
{ {
public: public:
CallInstruction(std::shared_ptr<ExternalFunction> ef,std::vector<TensorViewInfo> in, std::vector<TensorViewInfo> out) CallInstruction(std::shared_ptr<ExternalFunction> ef,
std::vector<TensorViewInfo> in,
std::vector<TensorViewInfo> out)
: m_external_function(ef) : m_external_function(ef)
, m_in(in) , m_in(in)
, m_out(out) , m_out(out)
...@@ -39,7 +41,7 @@ namespace ngraph ...@@ -39,7 +41,7 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
std::shared_ptr<CallFrame> cf = m_external_function->make_call_frame(); std::shared_ptr<CallFrame> cf = m_external_function->make_call_frame();
std::vector<std::shared_ptr<ngraph::runtime::Value>> inputs; std::vector<std::shared_ptr<ngraph::runtime::Value>> inputs;
std::vector<std::shared_ptr<ngraph::runtime::Value>> outputs; std::vector<std::shared_ptr<ngraph::runtime::Value>> outputs;
...@@ -51,7 +53,7 @@ namespace ngraph ...@@ -51,7 +53,7 @@ namespace ngraph
{ {
outputs.push_back(call_frame.get_tensor_view(out.get_index())); outputs.push_back(call_frame.get_tensor_view(out.get_index()));
} }
(*cf)(inputs,outputs); (*cf)(inputs, outputs);
} }
protected: protected:
......
...@@ -31,8 +31,8 @@ namespace ngraph ...@@ -31,8 +31,8 @@ namespace ngraph
{ {
public: public:
ConcatMatrixInstruction(const std::vector<TensorViewInfo>& args, ConcatMatrixInstruction(const std::vector<TensorViewInfo>& args,
size_t axis, size_t axis,
const TensorViewInfo& out) const TensorViewInfo& out)
: m_args(args) : m_args(args)
, m_axis(axis) , m_axis(axis)
, m_out(out) , m_out(out)
...@@ -59,9 +59,9 @@ namespace ngraph ...@@ -59,9 +59,9 @@ namespace ngraph
} }
protected: protected:
std::vector<TensorViewInfo> m_args; std::vector<TensorViewInfo> m_args;
size_t m_axis; size_t m_axis;
TensorViewInfo m_out; TensorViewInfo m_out;
std::vector<std::vector<size_t>> m_blocks; std::vector<std::vector<size_t>> m_blocks;
}; };
} }
......
...@@ -31,7 +31,7 @@ namespace ngraph ...@@ -31,7 +31,7 @@ namespace ngraph
{ {
public: public:
ConcatVectorInstruction(const std::vector<TensorViewInfo>& args, ConcatVectorInstruction(const std::vector<TensorViewInfo>& args,
const TensorViewInfo& out) const TensorViewInfo& out)
: m_args(args) : m_args(args)
, m_out(out) , m_out(out)
{ {
...@@ -46,15 +46,17 @@ namespace ngraph ...@@ -46,15 +46,17 @@ namespace ngraph
{ {
EigenVector<ET> out(call_frame, m_out); EigenVector<ET> out(call_frame, m_out);
size_t concat_pos = 0; size_t concat_pos = 0;
for (size_t i = 0; i < m_args.size(); i++){ for (size_t i = 0; i < m_args.size(); i++)
out.segment(concat_pos, m_sizes[i]) << EigenVector<ET>(call_frame, m_args.at(i)); {
out.segment(concat_pos, m_sizes[i])
<< EigenVector<ET>(call_frame, m_args.at(i));
concat_pos += m_sizes[i]; concat_pos += m_sizes[i];
} }
} }
protected: protected:
std::vector<TensorViewInfo> m_args; std::vector<TensorViewInfo> m_args;
TensorViewInfo m_out; TensorViewInfo m_out;
std::vector<size_t> m_sizes; std::vector<size_t> m_sizes;
}; };
} }
......
...@@ -30,7 +30,8 @@ namespace ngraph ...@@ -30,7 +30,8 @@ namespace ngraph
class ConstantInstruction : public Instruction class ConstantInstruction : public Instruction
{ {
public: public:
ConstantInstruction(const std::vector<typename ET::type> value, const TensorViewInfo& out) ConstantInstruction(const std::vector<typename ET::type> value,
const TensorViewInfo& out)
: m_value(value) : m_value(value)
, m_out(out) , m_out(out)
{ {
...@@ -38,12 +39,13 @@ namespace ngraph ...@@ -38,12 +39,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
call_frame.get_parameterized_tensor_view<ET>(m_out.get_index())->get_vector() = m_value; call_frame.get_parameterized_tensor_view<ET>(m_out.get_index())->get_vector() =
m_value;
} }
protected: protected:
const std::vector<typename ET::type> m_value; const std::vector<typename ET::type> m_value;
TensorViewInfo m_out; TensorViewInfo m_out;
}; };
} }
} }
......
...@@ -40,8 +40,9 @@ namespace ngraph ...@@ -40,8 +40,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) << EigenArray1d<ET>(call_frame, m_out)
EigenVector<ET>(call_frame, m_arg0).dot(EigenVector<ET>(call_frame, m_arg1)); << EigenVector<ET>(call_frame, m_arg0)
.dot(EigenVector<ET>(call_frame, m_arg1));
} }
protected: protected:
......
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
void less_than(TI arg0, TI arg1, TO out) void less_than(TI arg0, TI arg1, TO out)
{ {
auto result_as_float = get_map_array(&*arg0) < get_map_array(&*arg1); auto result_as_float = get_map_array(&*arg0) < get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>(); auto result_as_char = result_as_float.template cast<char>();
set_map_array(&*out, result_as_char); set_map_array(&*out, result_as_char);
} }
......
...@@ -37,7 +37,8 @@ namespace ngraph ...@@ -37,7 +37,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = Eigen::log(EigenArray1d<ET, fmt::V>(call_frame, m_arg)); EigenArray1d<ET, fmt::V>(call_frame, m_out) =
Eigen::log(EigenArray1d<ET, fmt::V>(call_frame, m_arg));
} }
protected: protected:
......
...@@ -27,7 +27,6 @@ namespace ngraph ...@@ -27,7 +27,6 @@ namespace ngraph
{ {
public: public:
ReturnInstruction() {} ReturnInstruction() {}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
call_frame.set_return(); call_frame.set_return();
......
...@@ -45,8 +45,8 @@ namespace ngraph ...@@ -45,8 +45,8 @@ namespace ngraph
// fmt::V computes sizes---it lumps together any higher // fmt::V computes sizes---it lumps together any higher
// dimensions---while fmt::M ignores them. // dimensions---while fmt::M ignores them.
EigenVector<ET>(call_frame, m_out) = EigenVector<ET>(call_frame, m_out) =
call_frame.get_tensor_view_data<ET>(m_arg0.get_index())[0] call_frame.get_tensor_view_data<ET>(m_arg0.get_index())[0] *
* EigenVector<ET>(call_frame, m_arg1); EigenVector<ET>(call_frame, m_arg1);
} }
protected: protected:
......
...@@ -31,7 +31,7 @@ namespace ngraph ...@@ -31,7 +31,7 @@ namespace ngraph
namespace eigen namespace eigen
{ {
using DynamicStrides = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>; using DynamicStrides = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
using VectorStrides = Eigen::Stride<Eigen::Dynamic, 1>; using VectorStrides = Eigen::Stride<Eigen::Dynamic, 1>;
template <typename ET> template <typename ET>
using DynamicArray = Eigen::Array<typename ET::type, Eigen::Dynamic, Eigen::Dynamic>; using DynamicArray = Eigen::Array<typename ET::type, Eigen::Dynamic, Eigen::Dynamic>;
...@@ -40,7 +40,8 @@ namespace ngraph ...@@ -40,7 +40,8 @@ namespace ngraph
using EigenArrayBase = Eigen::Map<DynamicArray<ET>, 0, DynamicStrides>; using EigenArrayBase = Eigen::Map<DynamicArray<ET>, 0, DynamicStrides>;
template <typename ET> template <typename ET>
using DynamicMatrix = Eigen::Matrix<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; using DynamicMatrix =
Eigen::Matrix<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
template <typename ET> template <typename ET>
using EigenMatrixBase = Eigen::Map<DynamicMatrix<ET>, 0, DynamicStrides>; using EigenMatrixBase = Eigen::Map<DynamicMatrix<ET>, 0, DynamicStrides>;
......
...@@ -81,7 +81,7 @@ using namespace ngraph::runtime; ...@@ -81,7 +81,7 @@ using namespace ngraph::runtime;
using ngraph::descriptor::layout::DenseTensorViewLayout; using ngraph::descriptor::layout::DenseTensorViewLayout;
ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& function, ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function) bool release_function)
: m_function(function) : m_function(function)
, m_release_function(release_function) , m_release_function(release_function)
, m_is_compiled(false) , m_is_compiled(false)
...@@ -89,30 +89,31 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -89,30 +89,31 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
{ {
} }
#define REGISTER_TO_OP_MAP(op_class) \ #define REGISTER_TO_OP_MAP(op_class) \
op_map[type_index(typeid(op_class))] = [](const Node* n, \ op_map[type_index(typeid(op_class))] = [](const Node* n, \
ExternalFunction* ef, \ ExternalFunction* ef, \
FunctionMap& function_map, \ FunctionMap& function_map, \
const std::vector<TensorViewInfo>& in, \ const std::vector<TensorViewInfo>& in, \
const std::vector<TensorViewInfo>& out) const std::vector<TensorViewInfo>& out)
#define REGISTER_INSTRUCTION(op_class, instr_class, ...) \ #define REGISTER_INSTRUCTION(op_class, instr_class, ...) \
REGISTER_TO_OP_MAP(op_class) { \ REGISTER_TO_OP_MAP(op_class) \
ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \ { \
ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \
} }
// Versions the include the descriptor // Versions the include the descriptor
#define REGISTER_UNOP(op_class, instr_class) \ #define REGISTER_UNOP(op_class, instr_class) \
REGISTER_INSTRUCTION(op_class, instr_class, in[0], out[0]) REGISTER_INSTRUCTION(op_class, instr_class, in[0], out[0])
#define REGISTER_BINOP(op_class, instr_class) \ #define REGISTER_BINOP(op_class, instr_class) \
REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], out[0]) REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], out[0])
#define REGISTER_TERNOP(op_class, instr_class) \ #define REGISTER_TERNOP(op_class, instr_class) \
REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], in[2], out[0]) REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], in[2], out[0])
// Define code generators for handled ops. // Define code generators for handled ops.
ExternalFunction::OpMap& ExternalFunction::get_op_map() ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
static bool initialized = false; static bool initialized = false;
static OpMap op_map; static OpMap op_map;
if (!initialized) if (!initialized)
{ {
...@@ -146,15 +147,15 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -146,15 +147,15 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
auto broadcast = static_cast<const op::Broadcast*>(n); auto broadcast = static_cast<const op::Broadcast*>(n);
auto arg_tensor_type = auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>(
dynamic_pointer_cast<const TensorViewType>(n->get_arguments().at(0)->get_value_type()); n->get_arguments().at(0)->get_value_type());
assert(nullptr != arg_tensor_type); assert(nullptr != arg_tensor_type);
auto result_tensor_type = auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type()); dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type); assert(nullptr != result_tensor_type);
auto arg_shape = arg_tensor_type->get_shape(); auto arg_shape = arg_tensor_type->get_shape();
auto result_shape = result_tensor_type->get_shape(); auto result_shape = result_tensor_type->get_shape();
if (broadcast->get_broadcast_axes().empty()) if (broadcast->get_broadcast_axes().empty())
...@@ -175,18 +176,22 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -175,18 +176,22 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
if (broadcast->get_broadcast_axes() == AxisSet{1}) if (broadcast->get_broadcast_axes() == AxisSet{1})
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::BroadcastVectorColwiseInstruction<element::Float32>>( make_shared<
runtime::eigen::BroadcastVectorColwiseInstruction<element::Float32>>(
in[0], out[0])); in[0], out[0]));
} }
else if (broadcast->get_broadcast_axes() == AxisSet{0}) else if (broadcast->get_broadcast_axes() == AxisSet{0})
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::BroadcastVectorRowwiseInstruction<element::Float32>>( make_shared<
runtime::eigen::BroadcastVectorRowwiseInstruction<element::Float32>>(
in[0], out[0])); in[0], out[0]));
} }
else else
{ {
throw ngraph_error("Internal error: axis set for vector-matrix broadcast is neither {0} or {1}"); throw ngraph_error(
"Internal error: axis set for vector-matrix broadcast is neither {0} or "
"{1}");
} }
} }
else else
...@@ -206,8 +211,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -206,8 +211,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
if (result_shape.size() == 1) if (result_shape.size() == 1)
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>( make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>(in,
in, out[0])); out[0]));
} }
else if (result_shape.size() == 2) else if (result_shape.size() == 2)
{ {
...@@ -286,7 +291,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -286,7 +291,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
}; };
// Parameter is a "runtime no-op" because the output tensor has already been filled. // Parameter is a "runtime no-op" because the output tensor has already been filled.
REGISTER_TO_OP_MAP(op::Parameter) {}; REGISTER_TO_OP_MAP(op::Parameter){};
// GetTupleElement will be spliced out, with the users of out redirected to in's source, but, for now, we need to copy. // GetTupleElement will be spliced out, with the users of out redirected to in's source, but, for now, we need to copy.
REGISTER_TO_OP_MAP(op::GetTupleElement) REGISTER_TO_OP_MAP(op::GetTupleElement)
...@@ -312,7 +317,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -312,7 +317,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
REGISTER_TO_OP_MAP(op::FunctionCall) REGISTER_TO_OP_MAP(op::FunctionCall)
{ {
auto function_call = static_cast<const op::FunctionCall*>(n); auto function_call = static_cast<const op::FunctionCall*>(n);
auto function = function_call->get_function(); auto function = function_call->get_function();
std::shared_ptr<ExternalFunction> external; std::shared_ptr<ExternalFunction> external;
...@@ -322,20 +327,16 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -322,20 +327,16 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
} }
catch (const std::out_of_range) catch (const std::out_of_range)
{ {
external = make_shared<ngraph::runtime::ExternalFunction>( external =
function_call->get_function()); make_shared<ngraph::runtime::ExternalFunction>(function_call->get_function());
function_map.insert({function,external}); function_map.insert({function, external});
} }
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::CallInstruction>(external,in,out)); make_shared<runtime::eigen::CallInstruction>(external, in, out));
};
REGISTER_TO_OP_MAP(op::Reduce)
{
throw ngraph_error("op::Reduce not implemented yet");
}; };
REGISTER_TO_OP_MAP(op::Reduce) { throw ngraph_error("op::Reduce not implemented yet"); };
initialized = true; initialized = true;
} }
return op_map; return op_map;
...@@ -379,8 +380,8 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -379,8 +380,8 @@ void ExternalFunction::compile(FunctionMap& function_map)
{ {
for (const descriptor::Output& output : param->get_outputs()) for (const descriptor::Output& output : param->get_outputs())
{ {
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
size_t index = tensor_index.size(); size_t index = tensor_index.size();
tensor_index[tv] = index; tensor_index[tv] = index;
} }
} }
...@@ -389,8 +390,8 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -389,8 +390,8 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Next are the function outputs // Next are the function outputs
for (const descriptor::Output& output : m_function->get_result()->get_outputs()) for (const descriptor::Output& output : m_function->get_result()->get_outputs())
{ {
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
size_t index = tensor_index.size(); size_t index = tensor_index.size();
tensor_index[tv] = index; tensor_index[tv] = index;
} }
m_n_outputs = tensor_index.size() - m_n_inputs; m_n_outputs = tensor_index.size() - m_n_inputs;
...@@ -403,7 +404,7 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -403,7 +404,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
if (0 == tensor_index.count(tv)) if (0 == tensor_index.count(tv))
{ {
size_t index = tensor_index.size(); size_t index = tensor_index.size();
tensor_index[tv] = index; tensor_index[tv] = index;
m_temp_views.push_back(tv); m_temp_views.push_back(tv);
} }
...@@ -423,7 +424,7 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -423,7 +424,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
for (const descriptor::Input& input : node->get_inputs()) for (const descriptor::Input& input : node->get_inputs())
{ {
const descriptor::Output& output = input.get_output(); const descriptor::Output& output = input.get_output();
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
in.push_back({tensor_index.at(tv), tv}); in.push_back({tensor_index.at(tv), tv});
} }
std::vector<TensorViewInfo> out; std::vector<TensorViewInfo> out;
......
...@@ -28,18 +28,19 @@ namespace ngraph ...@@ -28,18 +28,19 @@ namespace ngraph
{ {
class ExternalFunction class ExternalFunction
{ {
using FunctionMap = std::unordered_map<std::shared_ptr<Function>,std::shared_ptr<ExternalFunction>>; using FunctionMap =
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<ExternalFunction>>;
using OpFunction = std::function<void(const ngraph::Node*, using OpFunction = std::function<void(const ngraph::Node*,
ExternalFunction*, ExternalFunction*,
FunctionMap&, FunctionMap&,
const std::vector<TensorViewInfo>& inputs, const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)>; const std::vector<TensorViewInfo>& outputs)>;
using OpMap = std::unordered_map<std::type_index, OpFunction>; using OpMap = std::unordered_map<std::type_index, OpFunction>;
public: public:
ExternalFunction(const std::shared_ptr<ngraph::Function>& function, ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true); bool release_function = true);
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(); std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(FunctionMap& function_map); std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(FunctionMap& function_map);
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>> std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
...@@ -50,18 +51,17 @@ namespace ngraph ...@@ -50,18 +51,17 @@ namespace ngraph
// Release original function's resources // Release original function's resources
void release_function() { m_function = nullptr; } void release_function() { m_function = nullptr; }
protected: protected:
void compile(); void compile();
void compile(FunctionMap& function_map); void compile(FunctionMap& function_map);
std::shared_ptr<ngraph::Function> m_function; std::shared_ptr<ngraph::Function> m_function;
bool m_release_function; bool m_release_function;
bool m_is_compiled; bool m_is_compiled;
size_t m_n_inputs; size_t m_n_inputs;
size_t m_n_outputs; size_t m_n_outputs;
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>> std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
m_instructions; m_instructions;
ngraph::descriptor::TensorViewPtrs m_temp_views; ngraph::descriptor::TensorViewPtrs m_temp_views;
static OpMap& get_op_map(); static OpMap& get_op_map();
......
...@@ -49,7 +49,7 @@ namespace ngraph ...@@ -49,7 +49,7 @@ namespace ngraph
const std::shared_ptr<ngraph::descriptor::TensorView>& descriptor); const std::shared_ptr<ngraph::descriptor::TensorView>& descriptor);
using element_type = ET; using element_type = ET;
using value_type = typename ET::type; using value_type = typename ET::type;
using storage_type = std::vector<value_type>; using storage_type = std::vector<value_type>;
template <typename T> template <typename T>
...@@ -61,7 +61,6 @@ namespace ngraph ...@@ -61,7 +61,6 @@ namespace ngraph
// For getting the data out // For getting the data out
storage_type& get_vector() { return m_vector; } storage_type& get_vector() { return m_vector; }
protected: protected:
storage_type m_vector; storage_type m_vector;
}; };
......
...@@ -39,9 +39,7 @@ namespace ngraph ...@@ -39,9 +39,7 @@ namespace ngraph
public: public:
TensorView() {} TensorView() {}
virtual ~TensorView() {} virtual ~TensorView() {}
template <typename ET> template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor_view() ParameterizedTensorView<ET>* get_parameterized_tensor_view()
{ {
...@@ -65,7 +63,6 @@ namespace ngraph ...@@ -65,7 +63,6 @@ namespace ngraph
} }
const Shape& get_shape() { return m_descriptor->get_tensor_view_type()->get_shape(); } const Shape& get_shape() { return m_descriptor->get_tensor_view_type()->get_shape(); }
protected: protected:
std::shared_ptr<ngraph::descriptor::TensorView> m_descriptor; std::shared_ptr<ngraph::descriptor::TensorView> m_descriptor;
}; };
......
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
class TensorViewInfo class TensorViewInfo
{ {
public: public:
TensorViewInfo(size_t index, TensorViewInfo(size_t index,
const std::shared_ptr<ngraph::descriptor::TensorView>& descriptor) const std::shared_ptr<ngraph::descriptor::TensorView>& descriptor)
: m_index(index) : m_index(index)
, m_layout(descriptor->get_tensor_view_layout()) , m_layout(descriptor->get_tensor_view_layout())
...@@ -34,7 +34,6 @@ namespace ngraph ...@@ -34,7 +34,6 @@ namespace ngraph
} }
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
std::shared_ptr<ngraph::descriptor::layout::TensorViewLayout> std::shared_ptr<ngraph::descriptor::layout::TensorViewLayout>
get_tensor_view_layout() const get_tensor_view_layout() const
{ {
...@@ -48,7 +47,7 @@ namespace ngraph ...@@ -48,7 +47,7 @@ namespace ngraph
} }
protected: protected:
size_t m_index; size_t m_index;
std::shared_ptr<ngraph::descriptor::layout::TensorViewLayout> m_layout; std::shared_ptr<ngraph::descriptor::layout::TensorViewLayout> m_layout;
}; };
} }
......
...@@ -33,7 +33,7 @@ Tuple::Tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& element ...@@ -33,7 +33,7 @@ Tuple::Tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& element
} }
void Tuple::collect_tensor_views(std::vector<std::shared_ptr<TensorView>>& views, void Tuple::collect_tensor_views(std::vector<std::shared_ptr<TensorView>>& views,
const std::shared_ptr<Value>& value) const const std::shared_ptr<Value>& value) const
{ {
for (auto element : m_elements) for (auto element : m_elements)
{ {
......
...@@ -40,12 +40,11 @@ namespace ngraph ...@@ -40,12 +40,11 @@ namespace ngraph
return m_descriptor; return m_descriptor;
} }
virtual void virtual void collect_tensor_views(std::vector<std::shared_ptr<TensorView>>& views,
collect_tensor_views(std::vector<std::shared_ptr<TensorView>>& views, const std::shared_ptr<Value>& value) const override;
const std::shared_ptr<Value>& value) const override;
protected: protected:
std::vector<std::shared_ptr<Value>> m_elements; std::vector<std::shared_ptr<Value>> m_elements;
std::shared_ptr<ngraph::descriptor::Tuple> m_descriptor; std::shared_ptr<ngraph::descriptor::Tuple> m_descriptor;
}; };
} }
......
...@@ -24,17 +24,16 @@ namespace ngraph ...@@ -24,17 +24,16 @@ namespace ngraph
namespace runtime namespace runtime
{ {
class TensorView; class TensorView;
/// @brief A first-class runtime value. /// @brief A first-class runtime value.
class Value class Value
{ {
public: public:
virtual ~Value() {} virtual ~Value() {}
/// @brief The compile-time descriptor for this value. /// @brief The compile-time descriptor for this value.
virtual std::shared_ptr<ngraph::descriptor::Value> get_descriptor() const = 0; virtual std::shared_ptr<ngraph::descriptor::Value> get_descriptor() const = 0;
/// @brief helper for collecting all the tensor views in a sequence of values /// @brief helper for collecting all the tensor views in a sequence of values
/// ///
/// @param views The vector of tensor views being collected. /// @param views The vector of tensor views being collected.
/// @param value A shared pointer for this. /// @param value A shared pointer for this.
......
...@@ -33,7 +33,7 @@ size_t ngraph::shape_size(const Shape& shape) ...@@ -33,7 +33,7 @@ size_t ngraph::shape_size(const Shape& shape)
Strides ngraph::row_major_strides(const Shape& shape) Strides ngraph::row_major_strides(const Shape& shape)
{ {
Strides strides; Strides strides;
size_t s = 1; size_t s = 1;
for (auto d = shape.rbegin(); d != shape.rend(); d++) for (auto d = shape.rbegin(); d != shape.rend(); d++)
{ {
strides.push_back(s); strides.push_back(s);
......
...@@ -16,16 +16,16 @@ ...@@ -16,16 +16,16 @@
#include <cmath> #include <cmath>
#include <iostream> #include <iostream>
#include "ngraph/types/element_type.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph; using namespace ngraph;
std::map<std::string, ngraph::element::Type> ngraph::element::Type::m_element_list; std::map<std::string, ngraph::element::Type> ngraph::element::Type::m_element_list;
ngraph::element::Type::Type(size_t bitwidth, ngraph::element::Type::Type(size_t bitwidth,
bool is_float, bool is_float,
bool is_signed, bool is_signed,
const std::string& cname) const std::string& cname)
: m_bitwidth{bitwidth} : m_bitwidth{bitwidth}
, m_is_float{is_float} , m_is_float{is_float}
......
...@@ -37,23 +37,23 @@ namespace ngraph ...@@ -37,23 +37,23 @@ namespace ngraph
Type(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname); Type(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname);
const std::string& c_type_string() const; const std::string& c_type_string() const;
size_t size() const; size_t size() const;
size_t hash() const size_t hash() const
{ {
std::hash<std::string> h; std::hash<std::string> h;
return h(m_cname); return h(m_cname);
} }
bool operator==(const Type& other) const; bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); } bool operator!=(const Type& other) const { return !(*this == other); }
friend std::ostream& operator<<(std::ostream&, const Type&); friend std::ostream& operator<<(std::ostream&, const Type&);
private: private:
static std::map<std::string, Type> m_element_list; static std::map<std::string, Type> m_element_list;
size_t m_bitwidth; size_t m_bitwidth;
bool m_is_float; bool m_is_float;
bool m_is_signed; bool m_is_signed;
const std::string m_cname; const std::string m_cname;
}; };
std::ostream& operator<<(std::ostream& out, const ngraph::element::Type& obj); std::ostream& operator<<(std::ostream& out, const ngraph::element::Type& obj);
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -39,7 +39,8 @@ bool TensorViewType::operator==(const ValueType& that) const ...@@ -39,7 +39,8 @@ bool TensorViewType::operator==(const ValueType& that) const
return true; return true;
} }
void TensorViewType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const void TensorViewType::collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const
{ {
views.push_back(shared_from_this()); views.push_back(shared_from_this());
} }
...@@ -54,9 +55,10 @@ bool TupleType::operator==(const ValueType& that) const ...@@ -54,9 +55,10 @@ bool TupleType::operator==(const ValueType& that) const
return that_tvt->get_element_types() == get_element_types(); return that_tvt->get_element_types() == get_element_types();
} }
void TupleType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const void TupleType::collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const
{ {
for(auto elt : m_element_types) for (auto elt : m_element_types)
{ {
elt->collect_tensor_views(views); elt->collect_tensor_views(views);
} }
...@@ -70,7 +72,7 @@ std::ostream& ngraph::operator<<(std::ostream& out, const ValueType& obj) ...@@ -70,7 +72,7 @@ std::ostream& ngraph::operator<<(std::ostream& out, const ValueType& obj)
std::ostream& ngraph::operator<<(std::ostream& out, const TensorViewType& obj) std::ostream& ngraph::operator<<(std::ostream& out, const TensorViewType& obj)
{ {
out << "TensorViewType(" << obj.m_element_type << ", {" << join(obj.m_shape) << "})"; out << "TensorViewType(" << obj.m_element_type << ", {" << join(obj.m_shape) << "})";
return out; return out;
} }
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/types/element_type.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,12 +35,10 @@ namespace ngraph ...@@ -35,12 +35,10 @@ namespace ngraph
protected: protected:
ValueType() {} ValueType() {}
public: public:
virtual ~ValueType() {} virtual ~ValueType() {}
virtual bool operator==(const ValueType& that) const = 0; virtual bool operator==(const ValueType& that) const = 0;
bool operator!=(const ValueType& that) const { return !(*this == that); } bool operator!=(const ValueType& that) const { return !(*this == that); }
/// Add tensor views in depth-first order. /// Add tensor views in depth-first order.
virtual void collect_tensor_views( virtual void collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0; std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0;
...@@ -61,8 +59,7 @@ namespace ngraph ...@@ -61,8 +59,7 @@ namespace ngraph
} }
const element::Type& get_element_type() const { return m_element_type; } const element::Type& get_element_type() const { return m_element_type; }
const Shape& get_shape() const { return m_shape; } const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType& that) const override; virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views( virtual void collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const override; std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
...@@ -71,7 +68,7 @@ namespace ngraph ...@@ -71,7 +68,7 @@ namespace ngraph
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
Shape m_shape; Shape m_shape;
}; };
/// Describes a tuple of values; a vector of types /// Describes a tuple of values; a vector of types
...@@ -80,7 +77,6 @@ namespace ngraph ...@@ -80,7 +77,6 @@ namespace ngraph
public: public:
/// Construct empty tuple and add value types later. /// Construct empty tuple and add value types later.
TupleType() {} TupleType() {}
/// @param element_types A vector of types for the tuple elements /// @param element_types A vector of types for the tuple elements
TupleType(const std::vector<std::shared_ptr<const ValueType>>& element_types) TupleType(const std::vector<std::shared_ptr<const ValueType>>& element_types)
: m_element_types(element_types) : m_element_types(element_types)
...@@ -91,7 +87,10 @@ namespace ngraph ...@@ -91,7 +87,10 @@ namespace ngraph
{ {
return m_element_types; return m_element_types;
} }
std::vector<std::shared_ptr<const ValueType>> set_element_types() { return m_element_types; } std::vector<std::shared_ptr<const ValueType>> set_element_types()
{
return m_element_types;
}
virtual bool operator==(const ValueType& that) const override; virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views( virtual void collect_tensor_views(
......
...@@ -12,15 +12,15 @@ ...@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <iomanip>
#include <map>
#include <deque> #include <deque>
#include <forward_list> #include <forward_list>
#include <iomanip>
#include <map>
#include <unordered_set> #include <unordered_set>
#include "ngraph/util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -28,10 +28,10 @@ map<string, ngraph::stopwatch*> ngraph::stopwatch_statistics; ...@@ -28,10 +28,10 @@ map<string, ngraph::stopwatch*> ngraph::stopwatch_statistics;
void ngraph::dump(ostream& out, const void* _data, size_t _size) void ngraph::dump(ostream& out, const void* _data, size_t _size)
{ {
auto flags = out.flags(); auto flags = out.flags();
const uint8_t* data = reinterpret_cast<const uint8_t*>(_data); const uint8_t* data = reinterpret_cast<const uint8_t*>(_data);
size_t len = _size; size_t len = _size;
size_t index = 0; size_t index = 0;
while (index < len) while (index < len)
{ {
out << std::hex << std::setw(8) << std::setfill('0') << index; out << std::hex << std::setw(8) << std::setfill('0') << index;
...@@ -99,9 +99,9 @@ string ngraph::trim(const string& s) ...@@ -99,9 +99,9 @@ string ngraph::trim(const string& s)
vector<string> ngraph::split(const string& src, char delimiter, bool do_trim) vector<string> ngraph::split(const string& src, char delimiter, bool do_trim)
{ {
size_t pos; size_t pos;
string token; string token;
size_t start = 0; size_t start = 0;
vector<string> rc; vector<string> rc;
while ((pos = src.find(delimiter, start)) != std::string::npos) while ((pos = src.find(delimiter, start)) != std::string::npos)
{ {
...@@ -135,8 +135,7 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list) ...@@ -135,8 +135,7 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return seed; return seed;
} }
void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, std::function<void(Node*)> f)
std::function<void(Node*)> f)
{ {
std::unordered_set<Node*> instances_seen; std::unordered_set<Node*> instances_seen;
deque<Node*> stack; deque<Node*> stack;
...@@ -151,7 +150,10 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, ...@@ -151,7 +150,10 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p,
f(n); f(n);
} }
stack.pop_front(); stack.pop_front();
for (auto arg : n->get_arguments()) { stack.push_front(arg.get()); } for (auto arg : n->get_arguments())
{
stack.push_front(arg.get());
}
} }
} }
...@@ -159,10 +161,7 @@ void ngraph::free_nodes(shared_ptr<Node> p) ...@@ -159,10 +161,7 @@ void ngraph::free_nodes(shared_ptr<Node> p)
{ {
std::deque<Node*> sorted_list; std::deque<Node*> sorted_list;
traverse_nodes(p, [&](Node* n) traverse_nodes(p, [&](Node* n) { sorted_list.push_front(n); });
{
sorted_list.push_front(n);
});
for (Node* n : sorted_list) for (Node* n : sorted_list)
{ {
......
...@@ -18,10 +18,10 @@ ...@@ -18,10 +18,10 @@
#include <chrono> #include <chrono>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
namespace ngraph namespace ngraph
{ {
...@@ -114,7 +114,7 @@ namespace ngraph ...@@ -114,7 +114,7 @@ namespace ngraph
if (m_active == false) if (m_active == false)
{ {
m_total_count++; m_total_count++;
m_active = true; m_active = true;
m_start_time = m_clock.now(); m_start_time = m_clock.now();
} }
} }
...@@ -124,7 +124,7 @@ namespace ngraph ...@@ -124,7 +124,7 @@ namespace ngraph
if (m_active == true) if (m_active == true)
{ {
auto end_time = m_clock.now(); auto end_time = m_clock.now();
m_last_time = end_time - m_start_time; m_last_time = end_time - m_start_time;
m_total_time += m_last_time; m_total_time += m_last_time;
m_active = false; m_active = false;
} }
...@@ -151,14 +151,14 @@ namespace ngraph ...@@ -151,14 +151,14 @@ namespace ngraph
size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; } size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; }
size_t get_total_nanoseconds() const { return m_total_time.count(); } size_t get_total_nanoseconds() const { return m_total_time.count(); }
private: private:
std::chrono::high_resolution_clock m_clock; std::chrono::high_resolution_clock m_clock;
std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time; std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time;
bool m_active = false; bool m_active = false;
std::chrono::nanoseconds m_total_time = std::chrono::nanoseconds m_total_time =
std::chrono::high_resolution_clock::duration::zero(); std::chrono::high_resolution_clock::duration::zero();
std::chrono::nanoseconds m_last_time; std::chrono::nanoseconds m_last_time;
size_t m_total_count = 0; size_t m_total_count = 0;
std::string m_name; std::string m_name;
}; };
template <class InputIt, class BinaryOp> template <class InputIt, class BinaryOp>
......
...@@ -32,17 +32,17 @@ class ngraph::uuid_type ...@@ -32,17 +32,17 @@ class ngraph::uuid_type
public: public:
uuid_type() uuid_type()
{ {
m_data[0] = random_generator(); m_data[0] = random_generator();
m_data[1] = random_generator(); m_data[1] = random_generator();
uint8_t* p = (uint8_t*)m_data; uint8_t* p = (uint8_t*)m_data;
p[6] = (p[6] & 0x0F) | 0x40; p[6] = (p[6] & 0x0F) | 0x40;
p[8] = (p[8] & 0x3F) | 0x80; p[8] = (p[8] & 0x3F) | 0x80;
} }
std::string to_string() const std::string to_string() const
{ {
std::stringstream ss; std::stringstream ss;
uint8_t* p = (uint8_t*)m_data; uint8_t* p = (uint8_t*)m_data;
for (int i = 0; i < 4; i++) for (int i = 0; i < 4; i++)
ss << std::hex << std::setw(2) << std::setfill('0') << (int)*p++; ss << std::hex << std::setw(2) << std::setfill('0') << (int)*p++;
ss << "-"; ss << "-";
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <list> #include <list>
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/visualize.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/visualize.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -70,7 +70,7 @@ std::string Visualize::get_attributes(const Node* node) ...@@ -70,7 +70,7 @@ std::string Visualize::get_attributes(const Node* node)
void Visualize::save_dot(const string& path) const void Visualize::save_dot(const string& path) const
{ {
#ifdef GRAPHVIZ_FOUND #ifdef GRAPHVIZ_FOUND
auto tmp_file = path + ".tmp"; auto tmp_file = path + ".tmp";
ofstream out(tmp_file); ofstream out(tmp_file);
if (out) if (out)
{ {
...@@ -81,7 +81,7 @@ void Visualize::save_dot(const string& path) const ...@@ -81,7 +81,7 @@ void Visualize::save_dot(const string& path) const
stringstream ss; stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << path; ss << "dot -Tpng " << tmp_file << " -o " << path;
auto cmd = ss.str(); auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r"); auto stream = popen(cmd.c_str(), "r");
pclose(stream); pclose(stream);
......
...@@ -39,7 +39,7 @@ private: ...@@ -39,7 +39,7 @@ private:
std::string add_attributes(const Node* node); std::string add_attributes(const Node* node);
std::string get_attributes(const Node* node); std::string get_attributes(const Node* node);
std::stringstream m_ss; std::stringstream m_ss;
std::string m_name; std::string m_name;
std::set<const Node*> m_nodes_with_attributes; std::set<const Node*> m_nodes_with_attributes;
}; };
...@@ -23,18 +23,20 @@ using namespace ngraph; ...@@ -23,18 +23,20 @@ using namespace ngraph;
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{7, 3}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3}); auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7}); auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7}); auto arg3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto broadcast_1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto b1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = make_shared<op::Dot>(arg2, arg0); auto dot = make_shared<op::Dot>(arg2, arg0);
ASSERT_EQ(dot->get_arguments()[0], arg2); ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0); ASSERT_EQ(dot->get_arguments()[1], arg0);
auto result_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{10,32,7}); auto result_type =
auto cluster_0 = make_shared<Function>(dot, result_type, op::Parameters{arg0, arg1, arg2, arg3}); make_shared<TensorViewType>(element::Float32::element_type(), Shape{10, 32, 7});
auto cluster_0 =
make_shared<Function>(dot, result_type, op::Parameters{arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->get_result(), dot); ASSERT_EQ(cluster_0->get_result(), dot);
} }
...@@ -67,7 +69,7 @@ TEST(build_graph, node_comparison) ...@@ -67,7 +69,7 @@ TEST(build_graph, node_comparison)
auto dot = make_shared<op::Dot>(arg0, arg1); auto dot = make_shared<op::Dot>(arg0, arg1);
auto add = make_shared<op::Add>(dot, arg2); auto add = make_shared<op::Add>(dot, arg2);
auto parg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{}); auto parg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto pattern_dot = make_shared<op::Dot>(parg, parg); auto pattern_dot = make_shared<op::Dot>(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot)); ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented. // TODO This passes because typeid is not behaving as documented.
...@@ -79,7 +81,7 @@ TEST(build_graph, literal) ...@@ -79,7 +81,7 @@ TEST(build_graph, literal)
{ {
// float scalar from a float // float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0); //auto float0 = FloatScalarConstant::make(3.0);
auto float0 = make_shared<op::Float32ScalarConstant>(3.0); auto float0 = make_shared<op::Float32ScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{}); auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
ASSERT_EQ(float0->get_value(), 3.0); ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->get_value_type(), *float_scalar_type); ASSERT_EQ(*float0->get_value_type(), *float_scalar_type);
...@@ -92,7 +94,7 @@ TEST(build_graph, literal) ...@@ -92,7 +94,7 @@ TEST(build_graph, literal)
ASSERT_EQ(float1->get_value(), 3); ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->get_value_type(), *float_scalar_type); ASSERT_EQ(*float1->get_value_type(), *float_scalar_type);
auto int32_0 = make_shared<op::Int32ScalarConstant>(3.0); auto int32_0 = make_shared<op::Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{}); auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->get_value(), 3); ASSERT_EQ(int32_0->get_value(), 3);
ASSERT_EQ(*int32_0->get_value_type(), *int32_scalar_type); ASSERT_EQ(*int32_0->get_value_type(), *int32_scalar_type);
...@@ -182,4 +184,6 @@ TEST(build_graph, set_value_type_checked) ...@@ -182,4 +184,6 @@ TEST(build_graph, set_value_type_checked)
} }
// Check argument inverses // Check argument inverses
TEST(build_graph, arg_inverse) {} TEST(build_graph, arg_inverse)
{
}
This diff is collapsed.
...@@ -66,7 +66,7 @@ TEST(input_output, simple_output) ...@@ -66,7 +66,7 @@ TEST(input_output, simple_output)
auto tv_tp_0 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}); auto tv_tp_0 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto param_0 = make_shared<op::Parameter>(tv_tp_0); auto param_0 = make_shared<op::Parameter>(tv_tp_0);
auto param_1 = make_shared<op::Parameter>(tv_tp_0); auto param_1 = make_shared<op::Parameter>(tv_tp_0);
auto add = make_shared<op::Add>(param_0, param_1); auto add = make_shared<op::Add>(param_0, param_1);
// Sort the ops // Sort the ops
vector<shared_ptr<Node>> nodes; vector<shared_ptr<Node>> nodes;
......
...@@ -22,7 +22,7 @@ using namespace std; ...@@ -22,7 +22,7 @@ using namespace std;
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
const char* exclude = "--gtest_filter=-benchmark.*"; const char* exclude = "--gtest_filter=-benchmark.*";
vector<char*> argv_vector; vector<char*> argv_vector;
argv_vector.push_back(argv[0]); argv_vector.push_back(argv[0]);
argv_vector.push_back((char*)exclude); argv_vector.push_back((char*)exclude);
......
...@@ -13,12 +13,12 @@ ...@@ -13,12 +13,12 @@
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <iostream> #include <iostream>
#include <vector>
#include <mkldnn.hpp> #include <mkldnn.hpp>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
static int tensor_volume(const mkldnn::memory::dims &t) static int tensor_volume(const mkldnn::memory::dims& t)
{ {
int x = 1; int x = 1;
for (const auto i : t) for (const auto i : t)
...@@ -26,7 +26,6 @@ static int tensor_volume(const mkldnn::memory::dims &t) ...@@ -26,7 +26,6 @@ static int tensor_volume(const mkldnn::memory::dims &t)
return x; return x;
} }
TEST(mkldnn, engine) TEST(mkldnn, engine)
{ {
using namespace mkldnn; using namespace mkldnn;
...@@ -34,41 +33,55 @@ TEST(mkldnn, engine) ...@@ -34,41 +33,55 @@ TEST(mkldnn, engine)
#pragma GCC diagnostic ignored "-Wgnu-statement-expression" #pragma GCC diagnostic ignored "-Wgnu-statement-expression"
EXPECT_NO_THROW(({ EXPECT_NO_THROW(({
auto cpu_engine = engine(engine::cpu, 0); auto cpu_engine = engine(engine::cpu, 0);
const int mb = 2; const int mb = 2;
const int groups = 2; const int groups = 2;
memory::dims input_tz = {mb, 256, 13, 13}; memory::dims input_tz = {mb, 256, 13, 13};
memory::dims weights_tz = {groups, 384/groups, 256/groups, 3, 3}; memory::dims weights_tz = {groups, 384 / groups, 256 / groups, 3, 3};
memory::dims bias_tz = {384}; memory::dims bias_tz = {384};
memory::dims strides = {1, 1}; memory::dims strides = {1, 1};
memory::dims padding = {0, 0}; memory::dims padding = {0, 0};
memory::dims output_tz = {mb, 384, memory::dims output_tz = {
(input_tz[2] + 2*padding[0] - weights_tz[3])/strides[0] + 1, mb,
(input_tz[3] + 2*padding[1] - weights_tz[4])/strides[1] + 1, 384,
}; (input_tz[2] + 2 * padding[0] - weights_tz[3]) / strides[0] + 1,
(input_tz[3] + 2 * padding[1] - weights_tz[4]) / strides[1] + 1,
};
std::vector<float> input(tensor_volume(input_tz), .0f); std::vector<float> input(tensor_volume(input_tz), .0f);
std::vector<float> weights(tensor_volume(weights_tz), .0f); std::vector<float> weights(tensor_volume(weights_tz), .0f);
std::vector<float> bias(tensor_volume(bias_tz), .0f); std::vector<float> bias(tensor_volume(bias_tz), .0f);
std::vector<float> output(tensor_volume(output_tz), .0f); std::vector<float> output(tensor_volume(output_tz), .0f);
auto c3_src_desc = memory::desc({input_tz}, memory::data_type::f32, memory::format::nchw); auto c3_src_desc = memory::desc({input_tz}, memory::data_type::f32, memory::format::nchw);
auto c3_weights_desc = memory::desc({weights_tz}, memory::data_type::f32, memory::format::goihw); auto c3_weights_desc =
auto c3_bias_desc = memory::desc({bias_tz}, memory::data_type::f32, memory::format::x); memory::desc({weights_tz}, memory::data_type::f32, memory::format::goihw);
auto c3_dst_desc = memory::desc({output_tz}, memory::data_type::f32, memory::format::nchw); auto c3_bias_desc = memory::desc({bias_tz}, memory::data_type::f32, memory::format::x);
auto c3_dst_desc = memory::desc({output_tz}, memory::data_type::f32, memory::format::nchw);
auto c3_src = memory({c3_src_desc, cpu_engine}, input.data()); auto c3_src = memory({c3_src_desc, cpu_engine}, input.data());
auto c3_weights = memory({c3_weights_desc, cpu_engine}, weights.data()); auto c3_weights = memory({c3_weights_desc, cpu_engine}, weights.data());
auto c3_bias = memory({c3_bias_desc, cpu_engine}, bias.data()); auto c3_bias = memory({c3_bias_desc, cpu_engine}, bias.data());
auto c3_dst = memory({c3_dst_desc, cpu_engine}, output.data()); auto c3_dst = memory({c3_dst_desc, cpu_engine}, output.data());
auto c3 = convolution_forward(convolution_forward::primitive_desc(convolution_forward::desc(prop_kind::forward, auto c3 = convolution_forward(convolution_forward::primitive_desc(
algorithm::convolution_direct, convolution_forward::desc(prop_kind::forward,
c3_src_desc, c3_weights_desc, c3_bias_desc, c3_dst_desc, algorithm::convolution_direct,
strides, padding, padding, padding_kind::zero), c3_src_desc,
cpu_engine), c3_src, c3_weights, c3_bias, c3_dst); c3_weights_desc,
c3_bias_desc,
c3_dst_desc,
strides,
padding,
padding,
padding_kind::zero),
cpu_engine),
c3_src,
c3_weights,
c3_bias,
c3_dst);
stream(stream::kind::eager).submit({c3}).wait(); stream(stream::kind::eager).submit({c3}).wait();
})); }));
} }
...@@ -19,16 +19,16 @@ ...@@ -19,16 +19,16 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/pass/liveness.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/log.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
...@@ -49,7 +49,7 @@ TEST(pass, liveness) ...@@ -49,7 +49,7 @@ TEST(pass, liveness)
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::DumpSorted>(dump_file); pass_manager.register_pass<pass::DumpSorted>(dump_file);
shared_ptr<Function> func = make_test_graph(); shared_ptr<Function> func = make_test_graph();
pass_manager.run_passes(func.get()); pass_manager.run_passes(func.get());
auto sorted = pass_manager.get_call_graph(); auto sorted = pass_manager.get_call_graph();
...@@ -81,8 +81,6 @@ TEST(pass, liveness) ...@@ -81,8 +81,6 @@ TEST(pass, liveness)
// auto exc = ex.executor(seq_stuff); // auto exc = ex.executor(seq_stuff);
// return exc; // return exc;
// lg = LivenessGraph(exc.exop.ops) // lg = LivenessGraph(exc.exop.ops)
// lg.layout_memory() // lg.layout_memory()
......
...@@ -37,7 +37,7 @@ TEST(pass_manager, add) ...@@ -37,7 +37,7 @@ TEST(pass_manager, add)
pass_manager.register_pass<pass::PropagateTypes>(); pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>(); pass_manager.register_pass<pass::AssignTensors>();
auto graph = make_test_graph(); auto graph = make_test_graph();
size_t node_count = get_node_count(graph->get_result()); size_t node_count = get_node_count(graph->get_result());
pass_manager.run_passes(graph.get()); pass_manager.run_passes(graph.get());
auto sorted = pass_manager.get_call_graph(); auto sorted = pass_manager.get_call_graph();
......
...@@ -20,15 +20,15 @@ ...@@ -20,15 +20,15 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
using namespace ngraph; using namespace ngraph;
......
...@@ -12,20 +12,20 @@ ...@@ -12,20 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <memory>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/function.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
using namespace std; using namespace std;
...@@ -43,9 +43,9 @@ TEST(tensor, size) ...@@ -43,9 +43,9 @@ TEST(tensor, size)
{ {
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3});
auto add = make_shared<op::Add>(arg0, arg0); auto add = make_shared<op::Add>(arg0, arg0);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3}); auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3});
auto f0 = make_shared<Function>(add, rt, op::Parameters{arg0}); auto f0 = make_shared<Function>(add, rt, op::Parameters{arg0});
pass_manager.run_passes(f0); pass_manager.run_passes(f0);
...@@ -57,9 +57,9 @@ TEST(tensor, size) ...@@ -57,9 +57,9 @@ TEST(tensor, size)
{ {
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto add = make_shared<op::Add>(arg0, arg0); auto add = make_shared<op::Add>(arg0, arg0);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{}); auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto f0 = make_shared<Function>(add, rt, op::Parameters{arg0}); auto f0 = make_shared<Function>(add, rt, op::Parameters{arg0});
pass_manager.run_passes(f0); pass_manager.run_passes(f0);
...@@ -71,9 +71,9 @@ TEST(tensor, size) ...@@ -71,9 +71,9 @@ TEST(tensor, size)
{ {
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto add = make_shared<op::Add>(arg0, arg0); auto add = make_shared<op::Add>(arg0, arg0);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{1}); auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{1});
auto f0 = make_shared<Function>(add, rt, op::Parameters{arg0}); auto f0 = make_shared<Function>(add, rt, op::Parameters{arg0});
pass_manager.run_passes(f0); pass_manager.run_passes(f0);
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#include <algorithm> #include <algorithm>
#include "test_tools.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "test_tools.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -28,8 +28,8 @@ bool validate_list(const list<Node*>& nodes) ...@@ -28,8 +28,8 @@ bool validate_list(const list<Node*>& nodes)
bool rc = true; bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++) for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{ {
auto node_tmp = *it; auto node_tmp = *it;
auto dependencies_tmp = node_tmp->get_arguments(); auto dependencies_tmp = node_tmp->get_arguments();
vector<Node*> dependencies; vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp) for (shared_ptr<Node> n : dependencies_tmp)
{ {
...@@ -39,7 +39,7 @@ bool validate_list(const list<Node*>& nodes) ...@@ -39,7 +39,7 @@ bool validate_list(const list<Node*>& nodes)
for (; tmp != nodes.rend(); tmp++) for (; tmp != nodes.rend(); tmp++)
{ {
auto dep_tmp = *tmp; auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp); auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
if (found != dependencies.end()) if (found != dependencies.end())
{ {
dependencies.erase(found); dependencies.erase(found);
...@@ -73,7 +73,8 @@ shared_ptr<Function> make_test_graph() ...@@ -73,7 +73,8 @@ shared_ptr<Function> make_test_graph()
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{}); auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto f0 = make_shared<Function>(r0, rt, op::Parameters{arg_0, arg_1, arg_2, arg_3, arg_4, arg_5}); auto f0 =
make_shared<Function>(r0, rt, op::Parameters{arg_0, arg_1, arg_2, arg_3, arg_4, arg_5});
return f0; return f0;
} }
...@@ -81,9 +82,6 @@ shared_ptr<Function> make_test_graph() ...@@ -81,9 +82,6 @@ shared_ptr<Function> make_test_graph()
size_t get_node_count(std::shared_ptr<Node> n) size_t get_node_count(std::shared_ptr<Node> n)
{ {
size_t node_count = 0; size_t node_count = 0;
traverse_nodes(n, [&](const Node* node) { traverse_nodes(n, [&](const Node* node) { node_count++; });
node_count++;
});
return node_count; return node_count;
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment