Commit 8c16125d authored by Scott Cyphers

Merge branch 'master' into cyphers/view

parents 0064cfd0 8d57ce68
@@ -44,3 +44,10 @@ SpacesInSquareBrackets: false
SortIncludes: false
ReflowComments: true
IncludeCategories:
  - Regex: '^".*'
    Priority: 3
  - Regex: '^<.*'
    Priority: 2
SortIncludes: true
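With SortIncludes enabled, clang-format sorts each block of #include directives alphabetically and uses the IncludeCategories priorities to place system headers (matching '^<.*', priority 2) ahead of project headers (matching '^".*', priority 3) when they share a block. A small illustration of the effect, using hypothetical headers rather than files from this commit:

// Before formatting:
#include "util.hpp"
#include <vector>
#include <map>

// After clang-format with the configuration above:
#include <map>
#include <vector>
#include "util.hpp"

Most of the include churn in the hunks below is this reordering.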
@@ -27,7 +27,10 @@ const ngraph::ElementType element_type_uint64_t = ngraph::ElementType(64, false,

std::map<std::string, ngraph::ElementType> ngraph::ElementType::m_element_list;

ngraph::ElementType::ElementType(size_t bitwidth,
                                 bool is_float,
                                 bool is_signed,
                                 const std::string& cname)
    : m_bitwidth{bitwidth}
    , m_is_float{is_float}
    , m_is_signed{is_signed}
...
@@ -18,8 +18,8 @@
#pragma once

#include <map>
#include <string>

namespace ngraph
{
@@ -43,10 +43,10 @@ public:
private:
    static std::map<std::string, ElementType> m_element_list;

    size_t m_bitwidth;
    bool m_is_float;
    bool m_is_signed;
    const std::string m_cname;
};

extern const ngraph::ElementType element_type_float;
...
@@ -14,12 +14,12 @@
*/

#include <chrono>
#include <condition_variable>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <thread>

#include "log.hpp"
...
@@ -15,9 +15,9 @@
#pragma once

#include <deque>
#include <sstream>
#include <stdexcept>

namespace nervana
{
...
@@ -18,7 +18,7 @@
using namespace ngraph;

size_t NameableValue::__counter = 0;
std::map<std::string, NameableValue> NameableValue::__all_names;

NameableValue::NameableValue(const std::string& name,
...
@@ -14,96 +14,94 @@
#pragma once

#include <map>
#include <string>

namespace ngraph
{
    //================================================================================================
    // NameableValue
    //     An Axis labels a dimension of a tensor. The op-graph uses
    //     the identity of Axis objects to pair and specify dimensions in
    //     symbolic expressions. This system has several advantages over
    //     using the length and position of the axis as in other frameworks:
    //
    //     1) Convenience. The dimensions of tensors, which may be nested
    //     deep in a computation graph, can be specified without having to
    //     calculate their lengths.
    //
    //     2) Safety. Axis labels are analogous to types in general-purpose
    //     programming languages, allowing objects to interact only when
    //     they are permitted to do so in advance. In symbolic computation,
    //     this prevents interference between axes that happen to have the
    //     same lengths but are logically distinct, e.g. if the number of
    //     training examples and the number of input features are both 50.
    //
    //     TODO: Please add to the list...
    //
    // Arguments:
    //     length: The length of the axis.
    //     batch: Whether the axis is a batch axis.
    //     recurrent: Whether the axis is a recurrent axis.
    //================================================================================================
    class NameableValue
    {
    public:
        //!-----------------------------------------------------------------------------------
        //! NameableValue
        //!     An object that can be named.
        //!
        //! Arguments:
        //!     graph_label_type: A label that should be used when drawing the graph. Defaults to
        //!         the class name.
        //!     name (str): The name of the object.
        //!     **kwargs: Parameters for related classes.
        //!
        //! Attributes:
        //!     graph_label_type: A label that should be used when drawing the graph.
        //!     id: Unique id for this object.
        //!-----------------------------------------------------------------------------------
        NameableValue(const std::string& name,
                      const std::string& graph_label_type = "",
                      const std::string& doc_string = "");

        //!-----------------------------------------------------------------------------------
        //! graph_label
        //!     The label used for drawings of the graph.
        //!-----------------------------------------------------------------------------------
        const std::string& graph_label();

        //!-----------------------------------------------------------------------------------
        //! name
        //!     Sets the object name to a unique name based on name.
        //!
        //! Arguments:
        //!     name: Prefix for the name
        //!-----------------------------------------------------------------------------------
        const std::string& name();

        //!-----------------------------------------------------------------------------------
        //! name
        //!-----------------------------------------------------------------------------------
        void name(const std::string& name);

        //!-----------------------------------------------------------------------------------
        //! short_name
        //!-----------------------------------------------------------------------------------
        const std::string& short_name();

        //!-----------------------------------------------------------------------------------
        //! named
        //!-----------------------------------------------------------------------------------
        NameableValue& named(const std::string& name);

        static size_t __counter;
        static std::map<std::string, NameableValue> __all_names;

        std::string m_name;
        std::string m_graph_label;
        std::string m_short_name;
        std::string m_doc_string;
    };
} // end namespace ngraph
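A minimal usage sketch of the NameableValue interface declared above. The behaviour is inferred from the header comments only (the uniquifying scheme lives in the corresponding .cpp file, and the header's file name is not visible in this diff, so both are assumptions here):

#include <iostream>
#include "names.hpp" // assumed name of the header declaring ngraph::NameableValue (shown above)

void nameable_value_example()
{
    ngraph::NameableValue relu1{"relu"}; // name() yields a unique name derived from the "relu" prefix
    ngraph::NameableValue relu2{"relu"}; // a second object with the same prefix gets a different name

    std::cout << relu1.name() << " != " << relu2.name() << std::endl;
}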
#include <algorithm>
#include <iostream>

#include "strides.hpp"
#include "util.hpp"
...
#pragma once

#include <cstdio>
#include <initializer_list>
#include <vector>

#include "element_type.hpp"
#include "tree.hpp"
@@ -27,10 +27,9 @@ public:
                ElementType et = element_type_float);

    const ElementType& get_type() const { return m_element_type; }
    tensor_stride full_strides() const;
    tensor_stride strides() const;
    tensor_size sizes() const;

    tensor_size operator[](size_t index) const;
@@ -53,9 +52,8 @@ class ngraph::tensor_stride
public:
    tensor_stride();
    const ElementType& get_type() const { return m_element_type; }
    tensor_stride full_strides() const;
    tensor_stride strides() const;
    tensor_stride reduce_strides() const;
...
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <cassert>
#include <cmath>
#include <iostream>
#include <sstream>

#include "axes.hpp"
#include "util.hpp"
@@ -268,7 +268,7 @@ Axis ngraph::slice_axis(const Axis& axis, const slice& s)
std::vector<std::string> ngraph::duplicates(const std::vector<Axis>& ax)
{
    std::map<std::string, size_t> counts;
    std::vector<std::string> rc;
    for (const Axis& axis : ax)
    {
        auto it = counts.find(axis.name);
...
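The hunk above shows only the opening of ngraph::duplicates: a counts map keyed by axis name and an rc vector for the result. A self-contained sketch of how that counting pattern typically finishes, using plain strings in place of Axis (a plausible completion, not necessarily the exact body of the real function):

#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Report each name that occurs more than once, in first-seen order.
std::vector<std::string> duplicates_sketch(const std::vector<std::string>& names)
{
    std::map<std::string, std::size_t> counts;
    std::vector<std::string> rc;
    for (const std::string& name : names)
    {
        auto it = counts.find(name);
        if (it == counts.end())
        {
            counts[name] = 1; // first occurrence of this name
        }
        else if (++(it->second) == 2)
        {
            rc.push_back(name); // record a duplicated name only once
        }
    }
    return rc;
}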
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <cmath>
#include <exception>
#include <memory>
#include <sstream>

#include "exop.hpp"
#include "op_graph.hpp"
@@ -404,10 +404,10 @@ void ExOpBlock::add_ops(std::initializer_list<computation_op_ptr> roots, exop_pt
        }
    }
    std::vector<op_ptr> available;
    std::map<op_ptr, size_t> counts;
    std::map<op_ptr, std::vector<op_ptr>> parents;
    std::vector<op_ptr> ready;

    available.insert(available.end(), roots.begin(), roots.end());
    while (available.size() > 0)
@@ -1012,7 +1012,7 @@ tensor_decl_ptr ExecutionState::ensure_tensor_decl(ExecutionGraph& execut
    bool is_constant = false;
    bool is_compile_only = false;
    tensor_decl = std::make_shared<TensorDecl>(execution_graph,
                                               tensor_description_base->element_type(),
                                               tensor_description_base->tensor_size(),
                                               tensor_description_base->is_persistent(),
@@ -1057,7 +1057,7 @@ tensor_decl_ptr ExecutionGraph::get_tensor_decl(op_ptr op,
    bool is_constant = false;
    bool is_compile_only = false;
    tensor_decl = std::make_shared<TensorDecl>(*this,
                                               tensor_description_base->element_type(),
                                               tensor_description_base->tensor_size(),
                                               tensor_description_base->is_persistent(),
...
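The add_ops hunk above sets up the usual worklist structures for dependency-ordered scheduling: available ops, a per-op count of unscheduled dependencies, and a parents map from a value to the ops waiting on it. A self-contained sketch of that counting approach, with plain strings standing in for op_ptr (an illustration of the algorithm only, not the actual body of the member function):

#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Kahn-style ordering: an op is emitted once every op it depends on has been emitted.
// Assumes every op appears as a key in deps (possibly with an empty dependency list).
std::vector<std::string> schedule(const std::map<std::string, std::vector<std::string>>& deps)
{
    std::map<std::string, std::size_t> counts;               // unscheduled dependencies per op
    std::map<std::string, std::vector<std::string>> parents; // dependency -> ops waiting on it
    std::vector<std::string> ready;                          // ops with no unscheduled dependencies
    std::vector<std::string> order;                          // final schedule

    for (const auto& entry : deps)
    {
        counts[entry.first] = entry.second.size();
        if (entry.second.empty())
        {
            ready.push_back(entry.first);
        }
        for (const auto& dep : entry.second)
        {
            parents[dep].push_back(entry.first);
        }
    }
    while (!ready.empty())
    {
        std::string op = ready.back();
        ready.pop_back();
        order.push_back(op);
        for (const auto& user : parents[op])
        {
            if (--counts[user] == 0)
            {
                ready.push_back(user);
            }
        }
    }
    return order;
}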
@@ -15,443 +15,443 @@
#pragma once

#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include "axes.hpp"
#include "mock.hpp"
#include "op_graph.hpp"

namespace ngraph
{
    // forward declaration. This will hopefully go away
    class ExecutionGraph;
    class TensorDescription;
    class InputDecl;
    class OutputDecl;
    class TensorDecl;
    class TensorViewDecl;
    class ExOp;
    class Op;
    class ComputationDecl;
    class ExOpBlock;
    class ExecutionState;

    using output_decl_ptr = std::shared_ptr<OutputDecl>;
    using input_decl_ptr = std::shared_ptr<InputDecl>;
    using tensor_decl_ptr = std::shared_ptr<TensorDecl>;
    using tensor_view_decl_ptr = std::shared_ptr<TensorViewDecl>;
    using exop_ptr = std::shared_ptr<ExOp>;
    using computation_decl_ptr = std::shared_ptr<ComputationDecl>;
    using execution_graph_ptr = std::shared_ptr<ExecutionGraph>;
    using exop_block_ptr = std::shared_ptr<ExOpBlock>;
    using tensor_ptr = std::shared_ptr<TensorInterface>;
    using transformer_ptr = std::shared_ptr<Transformer>;
    using execution_state_ptr = std::shared_ptr<ExecutionState>;

    //================================================================================================
    // OutputDecl
    //     One value computed by an exop
    //
    // Arguments:
    //     exop: The exop.
    //     pos: The position of the value, defaults to 0.
    //     tensor_description: Tensor description of the value.
    //     write_view: The tensor view where the value is written.
    //
    // Attributes:
    //     exop: The exop.
    //     pos: The position of the value.
    //     tensor_description: Tensor description of the value.
    //     write_view: The tensor view where the value is written.
    //     value_users: Arguments using this value.
    //================================================================================================
    class OutputDecl
    {
    public:
        OutputDecl(const ExOp& _exop, size_t _pos, tensor_decl_ptr, tensor_description_ptr);
        tensor_decl_ptr tensor_decl();
        void tensor_decl(tensor_decl_ptr tensor_decl);
        tensor_view_decl_ptr write_view();
        void write_view(tensor_view_decl_ptr view);
        friend std::ostream& operator<<(std::ostream& out, const OutputDecl& obj);
        // def __repr__()
        // {
        //     return "Val({exop}:{pos})".format(exop=self.exop.name, pos=self.pos)
        // }
        bool is_tensor_op() const;

        const ExOp& exop;
        size_t pos;
        tensor_description_ptr tensor_description;
        tensor_decl_ptr __tensor;
        tensor_view_decl_ptr __write_view;
        std::set<InputDecl*> value_users;
    };

    //================================================================================================
    // InputDecl
    //     An argument for an exop.
    //
    // Arguments:
    //     exop: The exop.
    //     pos: The position of the value, defaults to 0.
    //     tensor_description: Tensor description of the value.
    //     read_view: The tensor view where the value is read from.
    //
    // Attributes:
    //     exop: The exop.
    //     pos: The position of the value.
    //     tensor_description: Tensor description of the value.
    //     read_view: The tensor view where the value is read from.
    //     value: Arguments supplying this value.
    //================================================================================================
    class InputDecl
    {
    public:
        InputDecl(const ExOp& _exop,
                  size_t _pos,
                  tensor_description_ptr _tensor_description,
                  OutputDecl* _value);
        TensorDecl& tensor_decl();
        OutputDecl* value();
        const OutputDecl* value() const;
        void value(OutputDecl* value);
        friend std::ostream& operator<<(std::ostream& out, const InputDecl& obj);

        const ExOp& exop;
        size_t pos;
        tensor_description_ptr tensor_description;
        tensor_view_decl_ptr read_view;
        OutputDecl* m_value;
    };

    //================================================================================================
    // ExecutionGraphElt
    //     An element of an execution graph.
    //
    // Arguments:
    //     execution_graph: The execution graph that indexes this exop.
    //
    // Attributes:
    //     execution_graph: The execution graph that indexes this exop.
    //================================================================================================
    class ExecutionGraphElt
    {
    public:
        ExecutionGraphElt(ExecutionGraph& eg)
            : execution_graph{eg}
        {
        }

        ExecutionGraph& execution_graph;
    };

    //================================================================================================
    // ExOp
    //================================================================================================
    class ExOp : public ExecutionGraphElt
    {
    public:
        // An exop that indicates an op to be executed.
        // The op might be different from what was originally found in the computation graph.
        // The args are exops that reflect the current version of the graph, and may differ
        // from the exops of the op's args.
        // The views_in are the current tensor views for the args.
        // The views_out are the current tensor views for any results.
        // Arguments:
        //     op: The op to execute.
        // Parameters:
        //     op: The computation graph op.
        //     views_in: Tensor views of the args.
        //     views_out: Tensor views of the result.
        // Attributes:
        //     op: The computation graph op to execute.
        //     args: exops for the arguments.
        //     views_in: Views for the arguments.
        //     views_out: Views for the results.
        //     tensor: Tensor of the primary output.
        //     tensor_view: View of the primary output.
        //     ref_ops: All computation graph ops covered by this op
        //     op_map: A map from ops to ref ops, sha
        ExOp(ComputationDecl& cgraph, op_ptr _op, bool create_value = true);
        friend std::ostream& operator<<(std::ostream& out, const ExOp& obj);

        // factory methods to make exops
        static exop_ptr literal_scalar_exop(scalar_t scalar, ComputationDecl& computation_graph);

        // A node in the graph, with inputs and outputs.
        InputDecl& add_arg(OutputDecl& value, tensor_description_ptr tensor_description = nullptr);
        InputDecl& add_write_arg(OutputDecl& value,
                                 tensor_description_ptr tensor_description = nullptr);
        OutputDecl& add_value(tensor_decl_ptr tensor_decl,
                              tensor_description_ptr tensor_description = nullptr);
        op_ptr get_op();
        void set_op(op_ptr _op);
        void add_ref_op(op_ptr _op);
        size_t memory_usage();
        size_t memory_footprint();
        size_t memory_efficiency();
        bool is_exop_end_of_list();
        std::string name() const;

        ComputationDecl& computation_graph;
        tensor_decl_ptr tensor_decl;
        tensor_view_decl_ptr tensor_view;
        std::vector<op_ptr> ref_ops;
        op_ptr op;
        std::vector<tensor_decl_ptr> liveness_live_list;
        std::vector<tensor_decl_ptr> liveness_free_list;
        std::vector<tensor_decl_ptr> liveness_new_list;
        std::vector<InputDecl> args;
        std::vector<InputDecl*>
            write_args; // TODO: Kludge until we have values with writers/readers
        std::vector<OutputDecl> values;
    };

    //================================================================================================
    // TensorDecl
    //================================================================================================
    class TensorDecl : public ExecutionGraphElt
    {
    public:
        // Allocate for a tensor.
        // Arguments:
        //     op: The AllocateTensorOp
        //     element_type: The type of the elements.
        //     size: The number of elements.
        //     is_persistent: True if the tensor is persistent.
        //     is_input: True if the tensor can be used as an argument.
        //     tensor_description_base: The base tensor description for the tensor.
        //     source_tensor: For a clone, the tensor that started the chain of clones
        //         this tensor is cloned from.
        // Parameters:
        //     op: The AllocateTensorOp
        //     element_type: The type of the elements.
        //     size: The number of elements.
        //     is_persistent: True if the tensor is persistent.
        //     is_input: True if the tensor can be used as an argument.
        //     is_output: True if the tensor needs to be available for output. Defaults to is_persistent.
        //     tensor_descriptions: The set of tensor descriptions for the tensor.
        //     tensor_description_base: The tensor description base for this tensor.
        //     is_compile_only: If True, this tensor is only needed during compilation, and should not be
        //         allocated.
        TensorDecl(ExecutionGraph&,
                   ElementType,
                   size_t,
                   bool _is_persistent,
                   bool _is_input,
                   tensor_description_ptr,
                   bool _is_output,
                   bool _is_constant,
                   tensor_description_ptr tensor_description,
                   bool _is_compile_only);
        tensor_view_decl_ptr get_tensor_view(tensor_description_ptr tensor_description = nullptr,
                                             InputDecl* reader = nullptr,
                                             OutputDecl* writer = nullptr);
        tensor_view_decl_ptr get_tensor_view(tensor_description_ptr tensor_description = nullptr,
                                             InputDecl* reader = nullptr);
        tensor_view_decl_ptr get_tensor_view(tensor_description_ptr tensor_description = nullptr,
                                             OutputDecl* writer = nullptr);
        void merge_flags(const TensorDecl& tensor);
        tensor_description_ptr buffer_key();
        std::string prefix();
        std::string variable_name();
        std::string tensor_name();
        std::string buffer_name();
        // std::string name();
        friend std::ostream& operator<<(std::ostream& out, const TensorDecl& obj);

        // op_ptr op;
        ElementType element_type;
        size_t size;
        bool is_persistent;
        bool is_input;
        bool is_output;
        size_t buffer_pool_offset;
        std::map<axes_key_t, tensor_view_decl_ptr> tensor_view_decls;
        tensor_description_ptr tensor_description_base;
        size_t lifespan;
        bool is_constant;
        bool is_compile_only;
        tensor_ptr initial_value;
        tensor_decl_ptr source_tensor;
    };

    //================================================================================================
    // ExOpBlock
    //================================================================================================
    class ExOpBlock : public ExecutionGraphElt
    {
    public:
        // Sequentially execute a list of exops.
        // Attributes:
        //     computation_graph: The associated computation graph.
        //     prev_exop: The last exop.
        //     next_exop: The first exop.
        //     root_set: Set of exops whose values are needed.
        ExOpBlock(ComputationDecl& cgraph);
        bool is_exop_end_of_list();
        void add_ops(std::initializer_list<computation_op_ptr> roots,
                     exop_ptr after_exop = nullptr);
        exop_ptr add_op(op_ptr op, exop_ptr after_exop);
        exop_ptr add_exop(exop_ptr exop, exop_ptr after_exop = nullptr);
        void move_exop_to_after_exop(exop_ptr exop, exop_ptr after_exop);
        void remove_exop(exop_ptr exop);
        void replace_op(op_ptr old_op, op_ptr new_op);
        void replace_users(exop_ptr old_exop, exop_ptr new_exop);
        void replace_value(OutputDecl* old_value, OutputDecl* new_value);
        void replace_exop(exop_ptr old_exop, exop_ptr new_exop);
        void merge_exop(exop_ptr old_exop, exop_ptr new_exop);
        size_t memory_footprint();
        size_t worst_case_footprint();
        size_t memory_efficiency();
        size_t persistent_size();
        std::set<OutputDecl*> get_vars();
        std::set<OutputDecl*> get_temp_vars();
        std::set<OutputDecl*> get_persistent_vars();

        ComputationDecl& computation_graph;
        std::set<ExOp*> root_set;

        // replacement for next_exop, prev_exop
        std::list<exop_ptr>::iterator begin() { return op_list.begin(); }
        std::list<exop_ptr>::iterator end() { return op_list.end(); }
        std::list<exop_ptr> op_list;
    };

    //================================================================================================
    // TensorViewDecl
    //================================================================================================
    class TensorViewDecl : public ExecutionGraphElt
    {
    public:
        // Declare a view of a tensor.
        // Arguments:
        //     tensor: The tensor.
        //     tensor_description: The description of the view.
        TensorViewDecl(TensorDecl&, tensor_description_ptr, ExecutionGraph&);
        std::string name() const;
        // op_ptr op();
        tensor_view_decl_ptr get_tensor_view(tensor_description_ptr, InputDecl*, OutputDecl*);
        tensor_view_decl_ptr get_tensor_view(tensor_description_ptr, InputDecl*);
        tensor_view_decl_ptr get_tensor_view(tensor_description_ptr, OutputDecl*);
        // def key()
        // {
        //     """
        //     // Returns: A tuple unique to this view of the tensor.
        //     """
        //     return tensor_description->parameter_key
        // }

        TensorDecl& tensor_decl;
        tensor_description_ptr tensor_description;
        // initializers;
        std::set<InputDecl*> readers;
        std::set<OutputDecl*> writers;
        OutputDecl* value;
    };

    // static exop_ptr _default_default;

    //================================================================================================
    // ComputationDecl
    //================================================================================================
    class ComputationDecl : public ExecutionGraphElt
    {
    public:
        // One computation to be run.
        // Every computation has its own execution graph. Persistent tensors are shared
        // between computations, other tensors are not.
        // Attributes:
        //     computation: The computation op.
        //     ops: A map from ops to the exop that handles the op in this computation.
        //     exop: The SSA block of exops for this computation.
        //     values: The ops whose values are returned from the computation.
        //     tensors: Map from base tensor descriptions to tensors.
        ComputationDecl(ExecutionGraph& eg, computation_op_ptr op);
        tensor_decl_ptr get_tensor_decl(op_ptr _op = nullptr);
        ExOp* get_exop(op_ptr _op);

        computation_op_ptr computation_op;
        std::map<op_ptr, ExOp*> ops;
        std::vector<tensor_decl_ptr> tensors;
        std::map<Op*, InputDecl*> op_returns; // op_returns_anchor?
        exop_block_ptr exop_block;
        exop_ptr returns;
        std::set<ExOp*> values;
    };

    //================================================================================================
    // ExecutionState
    //================================================================================================
    class ExecutionState
    {
    public:
        // Proxy for the state of a device.
        // Arguments:
        //     transformer: The associated transformer.
        ExecutionState(transformer_ptr transformer = nullptr);
        transformer_ptr transformer();
        execution_graph_ptr make_execution_graph(computation_op_ptr);
        tensor_decl_ptr get_op_tensor(op_ptr op);
        tensor_decl_ptr ensure_tensor_decl(ExecutionGraph&, tensor_description_ptr, op_ptr);

        transformer_ptr __transformer;
        // persistent tensors
        std::map<tensor_description_ptr, tensor_decl_ptr> __tensors_decls;
    };

    //================================================================================================
    // ExecutionGraph
    //================================================================================================
    class ExecutionGraph
    {
    public:
        // Information for compiling a computation_op.
        // Arguments:
        //     execution_state: The execution state the graph will be applied to. The definitions in
        //         the execution state can be used in the execution graph.
        //     computation_op: A computation to be processed
        ExecutionGraph(ExecutionState& execution_state, computation_op_ptr computation_op);
        tensor_decl_ptr get_tensor_decl(op_ptr, tensor_description_ptr = nullptr);

        ExecutionState& execution_state;
        // temporary tensors
        std::map<tensor_description_ptr, tensor_decl_ptr> tensor_decls;
        computation_decl_ptr computation_decl;
    };
} // end namespace ngraph
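The OutputDecl and InputDecl comments above describe a two-way producer/consumer link: an output records its consumers in value_users, and each input points back at the output that supplies it. A stand-alone sketch of that bookkeeping with simplified types (not the real ExOp machinery, which also tracks tensor views and descriptions):

#include <cstddef>
#include <set>

struct ValueSketch; // plays the role of OutputDecl

struct ArgSketch // plays the role of InputDecl
{
    ValueSketch* value = nullptr; // the value supplying this argument
};

struct ValueSketch
{
    std::size_t pos = 0;
    std::set<ArgSketch*> value_users; // arguments that consume this value
};

// Wire a produced value to a consuming argument and keep both sides consistent,
// in the spirit of InputDecl::value(OutputDecl*).
void connect(ValueSketch& value, ArgSketch& arg)
{
    if (arg.value != nullptr)
    {
        arg.value->value_users.erase(&arg); // detach from the previous producer
    }
    arg.value = &value;
    value.value_users.insert(&arg);
}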
@@ -14,169 +14,165 @@
#pragma once

#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>

#include "element_type.hpp"

namespace ngraph
{
    class ExecutionState;
    class Op;
    // class TensorDescription;
    class ComputationOp;

    using computation_op_ptr = std::shared_ptr<ComputationOp>;
    using op_ptr = std::shared_ptr<Op>;
    using scalar_t = float;

    //================================================================================================
    // TensorInterface
    //================================================================================================
    class TensorInterface
    {
    public:
        virtual ~TensorInterface() {}
        virtual const ElementType& element_type() const = 0;
        virtual std::string value_string() const = 0;
    };

    //================================================================================================
    // Tensor
    //================================================================================================
    template <typename T>
    class Tensor : public TensorInterface
    {
    public:
        Tensor(const T& val)
            : m_value{val}
            , m_element_type{element_type_float}
        {
        }

        virtual ~Tensor() {}
        const ElementType& element_type() const override { return m_element_type; }
        std::string value_string() const override
        {
            std::string rc = "WTF";
            if (std::is_floating_point<T>::value)
            {
                std::stringstream ss;
                ss << m_value;
                rc = ss.str();
            }
            return rc;
        }

    private:
        T m_value;
        ElementType m_element_type;
    };

    //================================================================================================
    // Transformer
    //================================================================================================
    class Transformer
    {
    public:
        virtual ~Transformer() {}
        virtual ExecutionState& execution_state() = 0;
    };

    //================================================================================================
    // TensorDescription
    //================================================================================================
    // class TensorDescription
    // {
    // public:
    //     virtual ~TensorDescription();
    //     virtual axes_key_t axes_key() const = 0;
    //     virtual std::string name() const = 0;
    //     virtual std::vector<size_t> shape() const = 0;
    //     virtual std::shared_ptr<TensorDescription> base() = 0;
    //     virtual ElementType element_type() const = 0;
    //     virtual size_t tensor_size() = 0;
    //     virtual bool is_persistent() = 0;
    //     virtual bool is_input() = 0;
    // };

    //================================================================================================
    // Op
    //================================================================================================
    // class Op
    // {
    //     // Any operation that can be in an AST.
    //
    //     // Arguments:
    //     //     args: Values used by this node.
    //     //     const: The value of a constant Op, or None,
    //     //     constant (bool): The Op is constant. Default False.
    //     //     forward: If not None, the node to use instead of this node.
    //     //     metadata: String key value dictionary for frontend metadata.
    //     //     kwargs: Args defined in related classes.
    //
    //     // Attributes:
    //     //     const: The value of a constant.
    //     //     constant (bool): The value is constant.
    //     //     control_deps (OrderedSet): Ops in addition to args that must run before this op.
    //     //     persistent (bool): The value will be retained from computation to computation and
    //     //         not shared. Always True if reference is set.
    //     //     metadata: Dictionary of string keys and values used for attaching
    //     //         arbitrary metadata to nodes.
    //     //     trainable: The value is trainable.
    // public:
    //     virtual ~Op() {}
    //
    //     virtual std::string name() const = 0;
    //     virtual tensor_description_ptr tensor_description() = 0;
    //     virtual op_ptr tensor() = 0;
    //     virtual bool is_tensor_op() = 0;
    //     virtual bool is_state_op() const = 0;
    //     virtual bool is_sequencing_op() const = 0;
    //     virtual op_ptr effective_tensor_op() = 0;
    //     virtual const std::vector<op_ptr>& all_deps() const = 0;
    //
    //     // ops
    //
    //     // TODO support multiple types
    //     static op_ptr constant(float value)
    //     {
    //         op_ptr = make_shared<LiteralScalarOp>(value);
    //     }
    // };

    //================================================================================================
    // TensorOp
    //================================================================================================
    // class TensorOp : public Op
    // {
    // public:
    //     std::string name() const override { return "TensorOp"; }
    //     tensor_description_ptr tensor_description() override { return nullptr; }
    //     op_ptr tensor() override { return nullptr; }
    //     bool is_tensor_op() override { return true; }
    //     bool is_state_op() const override { return false; }
    //     op_ptr effective_tensor_op() override { return nullptr; }
    //     const std::vector<op_ptr>& all_deps() const override { return m_all_deps; }
    //
    // private:
    //     std::vector<op_ptr> m_all_deps;
    // };
} // end of namespace ngraph
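Tensor<T>::value_string above only formats floating-point payloads; any other T currently falls through to the placeholder string. A brief usage sketch, assuming this header is the mock.hpp included elsewhere in this merge (the file name itself is not shown on this page):

#include <iostream>
#include "mock.hpp" // assumed name of the header declaring ngraph::Tensor

void tensor_value_string_example()
{
    ngraph::Tensor<float> scalar{2.5f};
    ngraph::Tensor<int> count{7};

    std::cout << scalar.value_string() << std::endl; // "2.5", via the std::is_floating_point branch
    std::cout << count.value_string() << std::endl;  // the placeholder, since int is not floating point
}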
@@ -14,24 +14,21 @@
#pragma once

#include "exop.hpp"
#include "mock.hpp"

namespace ngraph
{
    //================================================================================================
    // CpuTransformer
    //================================================================================================
    class CpuTransformer : public Transformer
    {
    public:
        virtual ~CpuTransformer() {}
        ExecutionState& execution_state() override { return m_execution_state; }

    private:
        ExecutionState m_execution_state;
    };
} // end namespace ngraph
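CpuTransformer above simply owns an ExecutionState and hands it out through the Transformer interface. A sketch of how the pair might be exercised, based only on declarations visible in this diff (ExecutionState::make_execution_graph and the computation_op_ptr alias); the surrounding setup and the header name are assumptions:

// Assumes the header declaring ngraph::CpuTransformer (shown above) is included.
void cpu_transformer_example(ngraph::computation_op_ptr computation)
{
    ngraph::CpuTransformer transformer;

    // The transformer exposes its device-state proxy...
    ngraph::ExecutionState& state = transformer.execution_state();

    // ...which can build an execution graph for a computation op.
    ngraph::execution_graph_ptr graph = state.make_execution_graph(computation);
    (void)graph; // compilation and execution steps are outside what this diff shows
}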
@@ -14,8 +14,8 @@
#pragma once

#include <memory>
#include <vector>

#include "element_type.hpp"
#include "strides.hpp"
...
@@ -14,8 +14,8 @@
#include <sstream>

#include "axes.hpp"
#include "op_graph.hpp"
#include "util.hpp"

using namespace ngraph;
@@ -2794,7 +2794,9 @@ ElementWiseOp::ElementWiseOp()
{
}

void ElementWiseOp::ElementWiseOp_init(std::vector<op_ptr>, Axes)
{
}

//================================================================================================
// UnaryElementWiseOp
...
#pragma once

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <vector>

#include "util.hpp"
@@ -51,7 +51,6 @@ public:
    bool is_list() const { return m_is_list; }
    T get_value() const { return m_value; }
    const std::vector<tree>& get_list() const { return m_list; }
    static void traverse_tree(tree& s, std::function<void(T*)> func)
    {
        if (s.is_list())
...
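traverse_tree above recurses into list nodes; only the start of its body appears in this hunk, so the leaf handling (applying func to each stored value) is an assumption here, as are the ngraph namespace and the nested-initializer-list construction used below:

#include <iostream>
// Assumes the header declaring the tree template (shown above) is included; its file name
// is not visible in this diff.

void traverse_tree_example()
{
    // Built from nested initializer lists, as the <initializer_list> include above suggests.
    ngraph::tree<int> t{{1, 2}, {3, 4}};

    int sum = 0;
    ngraph::tree<int>::traverse_tree(t, [&sum](int* value) { sum += *value; });
    std::cout << "sum of leaves = " << sum << std::endl;
}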
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <iomanip>
#include <map>

#include "util.hpp"
...
@@ -14,184 +14,183 @@
#pragma once

#include <algorithm>
#include <chrono>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

namespace ngraph
{
    class stopwatch;
    extern std::map<std::string, stopwatch*> stopwatch_statistics;

    template <typename T>
    std::string join(const T& v, const std::string& sep)
    {
        std::ostringstream ss;
        for (const auto& x : v)
        {
            if (&x != &*(v.begin()))
            {
                ss << sep;
            }
            ss << x;
        }
        return ss.str();
    }

    template <typename U, typename T>
    bool contains(const U& container, const T& obj)
    {
        bool rc = false;
        for (auto o : container)
        {
            if (o == obj)
            {
                rc = true;
                break;
            }
        }
        return rc;
    }

    template <typename U, typename T>
    bool contains_key(const U& container, const T& obj)
    {
        bool rc = false;
        for (auto o : container)
        {
            if (o.first == obj)
            {
                rc = true;
                break;
            }
        }
        return rc;
    }

    template <typename U, typename T>
    void remove_from(U& container, const T& obj)
    {
        auto it = container.find(obj);
        if (it != container.end())
        {
            container.erase(it);
        }
    }

    size_t hash_combine(const std::vector<size_t>& list);
    void dump(std::ostream& out, const void*, size_t);
    std::string to_lower(const std::string& s);
    std::string trim(const std::string& s);
    std::vector<std::string> split(const std::string& s, char delimiter, bool trim = false);

    class stopwatch
    {
    public:
        stopwatch() {}
        stopwatch(const std::string& name)
            : m_name{name}
        {
            stopwatch_statistics.insert({m_name, this});
        }

        ~stopwatch()
        {
            if (m_name.size() > 0)
            {
                stopwatch_statistics.find(m_name);
            }
        }

        void start()
        {
            if (m_active == false)
            {
                m_total_count++;
                m_active = true;
                m_start_time = m_clock.now();
            }
        }

        void stop()
        {
            if (m_active == true)
            {
                auto end_time = m_clock.now();
                m_last_time = end_time - m_start_time;
                m_total_time += m_last_time;
                m_active = false;
            }
        }
        size_t get_call_count() const { return m_total_count; }
        size_t get_seconds() const { return get_nanoseconds() / 1e9; }
        size_t get_milliseconds() const { return get_nanoseconds() / 1e6; }
        size_t get_microseconds() const { return get_nanoseconds() / 1e3; }
        size_t get_nanoseconds() const
        {
            if (m_active)
            {
                return (m_clock.now() - m_start_time).count();
            }
            else
            {
                return m_last_time.count();
            }
        }

        size_t get_total_seconds() const { return get_total_nanoseconds() / 1e9; }
        size_t get_total_milliseconds() const { return get_total_nanoseconds() / 1e6; }
        size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; }
        size_t get_total_nanoseconds() const { return m_total_time.count(); }

    private:
        std::chrono::high_resolution_clock m_clock;
        std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time;
        bool m_active = false;
        std::chrono::nanoseconds m_total_time =
            std::chrono::high_resolution_clock::duration::zero();
        std::chrono::nanoseconds m_last_time;
        size_t m_total_count = 0;
        std::string m_name;
    };

    template <class InputIt, class BinaryOp>
    typename std::iterator_traits<InputIt>::value_type
        reduce(InputIt first, InputIt last, BinaryOp op)
    {
        typename std::iterator_traits<InputIt>::value_type result;
        if (first == last)
        {
            result = {};
        }
        else
        {
            result = *first++;
            while (first != last)
            {
                result = op(result, *first);
                first++;
            }
        }
        return result;
    }

    template <typename T>
    T plus(const T& a, const T& b)
    {
        return a + b;
    }

    template <typename T>
    T mul(const T& a, const T& b)
    {
        return a * b;
    }

} // end namespace ngraph
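A minimal usage sketch for the utilities above (not part of the commit; the "util.hpp" include path is assumed, and the program must link against the ngraph util sources that define stopwatch_statistics):

    #include <iostream>
    #include <vector>

    #include "util.hpp"

    int main()
    {
        std::vector<size_t> v{1, 2, 3, 4};

        // join/contains work on any iterable container.
        std::cout << ngraph::join(v, ", ") << "\n";           // "1, 2, 3, 4"
        std::cout << ngraph::contains(v, size_t{3}) << "\n";  // 1 (true)

        // stopwatch accumulates time across start/stop pairs.
        ngraph::stopwatch t;
        t.start();
        // ... work being timed ...
        t.stop();
        std::cout << t.get_total_microseconds() << " us over " << t.get_call_count()
                  << " call(s)\n";

        // reduce folds a range with a binary op; plus and mul are ready-made ops.
        size_t sum = ngraph::reduce(v.begin(), v.end(), ngraph::plus<size_t>);
        size_t product = ngraph::reduce(v.begin(), v.end(), ngraph::mul<size_t>);
        std::cout << sum << " " << product << "\n";           // "10 24"
        return 0;
    }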
@@ -15,10 +15,10 @@
#pragma once

#include <array>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <random>

static std::mt19937_64 random_generator;
@@ -74,7 +74,6 @@ public:
    }
    bool operator!=(const uuid_type& other) const { return !(*this == other); }
    friend std::ostream& operator<<(std::ostream& out, const uuid_type& id)
    {
        out << id.to_string();
...
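A hedged sketch of how uuid_type might be exercised (not part of the commit; it assumes the default constructor draws a fresh id from random_generator, which is outside this hunk):

    uuid_type a;
    uuid_type b;
    if (a != b)
    {
        std::cout << a << "\n"; // streams a.to_string()
    }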
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "gtest/gtest.h"
@@ -310,7 +310,7 @@ TEST(axes, index)
    EXPECT_EQ(7, b[1].length());
}

TEST(axes, DISABLED_as_nested_list)
{
    Axis C = make_axis(5);
    Axis H = make_axis(3);
@@ -325,7 +325,7 @@ TEST(axes, as_nested_list)
    FAIL();
}

TEST(axes, DISABLED_flatten)
{
    Axis C = make_axis(5);
    Axis H = make_axis(3);
@@ -336,7 +336,7 @@ TEST(axes, flatten)
    EXPECT_TRUE(c.is_flattened());
}

TEST(axes, DISABLED_as_flattened_list)
{
    FAIL();
}
@@ -364,7 +364,7 @@ TEST(axes, hash_axes)
    m2[axes] = 1;
}

TEST(axes, DISABLED_reaxe_0d_to_1d)
{
    TensorDescription td{};
    ngraph::ndarray x = random(td);
@@ -382,7 +382,7 @@ TEST(axes, reaxe_0d_to_1d)
    FAIL();
}

TEST(axes, DISABLED_reaxe_0d_to_2d)
{
    // td = TensorDescription(axes=())
    // x = random(td)
@@ -407,7 +407,7 @@ TEST(axes, reaxe_0d_to_2d)
// I started refactoring into smaller pieces as seen in tests above, but
// stopped ...
//-----------------------------------------------------------------------------------------------
TEST(axes, DISABLED_simple_tensors)
{
    // # A simple vector
    // td1 = TensorDescription(axes=[ax_A])
@@ -582,7 +582,7 @@ TEST(axes, axes_map)
    // assert axes_after == axes_map.map_axes(axes_before)
}

TEST(axes, DISABLED_axes_map_immutable)
{
    FAIL();
    // axes_map = AxesMap({})
@@ -591,7 +591,7 @@ TEST(axes, axes_map_immutable)
    // axes_map["x"] = "y"
}

TEST(axes, DISABLED_axes_map_init_from_axes)
{
    FAIL();
    // axes_map = AxesMap({ng.make_axis(1, name="aaa"): ng.make_axis(1, name="zzz")})
...
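Note on the renames in this file: prefixing a gtest case name with DISABLED_ keeps the test compiled but skips it in normal runs, which suits these placeholder bodies that still call FAIL(). They can still be exercised on demand with standard gtest flags (the test binary name here is an assumption, not something defined by this commit):

    ./unit-test --gtest_also_run_disabled_tests --gtest_filter='axes.*'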
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <sstream>
#include <string>
#include <vector>

#include "gtest/gtest.h"
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <sstream>
#include <string>
#include <vector>

#include "gtest/gtest.h"
...
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <chrono>
#include <iostream>

#include "gtest/gtest.h"
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <sstream>
#include <string>
#include <vector>

#include "gtest/gtest.h"
@@ -22,4 +22,6 @@
using namespace ngraph;

TEST(names, name)
{
}
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <sstream>
#include <string>
#include <vector>

#include "gtest/gtest.h"
...
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "gtest/gtest.h"
...
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "gtest/gtest.h"
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <sstream>
#include <string>
#include <vector>

#include "gtest/gtest.h"
@@ -134,7 +134,9 @@ TEST(util, contains)
    EXPECT_FALSE(contains(v1, 8));
}

TEST(util, remove_from)
{
}

TEST(util, reduce)
{
...
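TEST(util, remove_from) above is still an empty placeholder. One possible body, sketched against the remove_from and contains_key helpers from util.hpp (an illustration only, not part of the commit; it assumes <map> is also included in this test file):

    TEST(util, remove_from)
    {
        std::map<std::string, int> m{{"a", 1}, {"b", 2}};
        remove_from(m, std::string("a"));
        EXPECT_EQ(1u, m.size());
        EXPECT_FALSE(contains_key(m, std::string("a")));
    }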
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#include <sstream>
#include <string>
#include <vector>

#include "gtest/gtest.h"
...