Commit 973b3a0e authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #74 from NervanaSystems/cyphers/names

Cyphers/names
parents fac27c37 fd881acc
...@@ -39,6 +39,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded") ...@@ -39,6 +39,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-potentially-evaluated-expression") # Triggers false alarms on typeid set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-potentially-evaluated-expression") # Triggers false alarms on typeid
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-weak-vtables") # Not ready for this yet
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-conversion") # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-conversion")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-equal") # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-equal")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-duplicate-enum") # from numpy # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-duplicate-enum") # from numpy
......
...@@ -15,20 +15,33 @@ ...@@ -15,20 +15,33 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <vector>
#include <set> #include <set>
#include <vector>
// Names for types that aren't worth giving their own classes // Names for types that aren't worth giving their own classes
namespace ngraph namespace ngraph
{ {
class Node; class Node;
class Parameter; class Parameter;
class ValueType;
template <typename T, typename... A>
std::shared_ptr<T> node(A&&... args)
{
return std::make_shared<T>(args...);
}
/// Zero or more value types
using ValueTypes = std::vector<std::shared_ptr<ValueType>>;
/// Zero or more nodes /// Zero or more nodes
using Nodes = std::vector<std::shared_ptr<Node>>; using Nodes = std::vector<std::shared_ptr<Node>>;
/// A set of indices, for example, reduction axes /// A sequence of axes
using IndexSet = std::set<size_t>; using AxisVector = std::vector<size_t>;
/// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>;
/// A list of parameters /// A list of parameters
using Parameters = std::vector<std::shared_ptr<Parameter>>; using Parameters = std::vector<std::shared_ptr<Parameter>>;
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include "except.hpp"
namespace ngraph namespace ngraph
{ {
namespace element namespace element
...@@ -41,42 +43,77 @@ namespace ngraph ...@@ -41,42 +43,77 @@ namespace ngraph
bool operator==(const Type& other) const; bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); } bool operator!=(const Type& other) const { return !(*this == other); }
private: private:
static std::map<std::string, Type> m_element_list; static std::map<std::string, Type> m_element_list;
size_t m_bitwidth; size_t m_bitwidth;
bool m_is_float; bool m_is_float;
bool m_is_signed; bool m_is_signed;
const std::string m_cname; const std::string& m_cname;
}; };
// Provides a compile-time name for a C++ type.
// Used in TraitedType for the string that supplies the C++ type name during code generation,
// so it needs to be a valid C++ name.
template <typename T>
const char* traited_type_name()
{
throw ngraph_error("Unknown type");
}
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TRAITED_TYPE_NAME(T) \
template <> \
constexpr const char* traited_type_name<T>() \
{ \
return #T; \
}
// Literals (and probably other things we don't know about yet) need to have their C++ types // Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides // and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation. // access to both the instance and the C++ type used to hold the value during compilation.
template <typename T> template <typename T>
class TraitedType : public Type class TraitedType : public Type
{ {
public: protected:
// This is the C++ type used to hold a value of this element type during compilation TraitedType()
using ctype = T;
// This is a reference to an instance of this element type.
static const TraitedType<T>& type;
TraitedType(const std::string& cname)
: Type(sizeof(T) * 8, : Type(sizeof(T) * 8,
std::is_floating_point<T>::value, std::is_floating_point<T>::value,
std::is_signed<T>::value, std::is_signed<T>::value,
cname) traited_type_name<T>())
{ {
} }
public:
// This is the C++ type used to hold a value of this element type during compilation
using type = T;
// This returns a reference to an instance of this element type.
static const TraitedType<T>& element_type()
{
static TraitedType<T> t;
return t;
}
}; };
// Human-readable names for the element types NGRAPH_DEFINE_TRAITED_TYPE_NAME(float)
using Float = TraitedType<float>; using Float32 = TraitedType<float>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int8_t)
using Int8 = TraitedType<int8_t>; using Int8 = TraitedType<int8_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int32_t)
using Int32 = TraitedType<int32_t>; using Int32 = TraitedType<int32_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int64_t)
using Int64 = TraitedType<int64_t>; using Int64 = TraitedType<int64_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint8_t)
using UInt8 = TraitedType<uint8_t>; using UInt8 = TraitedType<uint8_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint32_t)
using UInt32 = TraitedType<uint32_t>; using UInt32 = TraitedType<uint32_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint64_t)
using UInt64 = TraitedType<uint64_t>; using UInt64 = TraitedType<uint64_t>;
} }
} }
...@@ -21,20 +21,22 @@ ...@@ -21,20 +21,22 @@
namespace ngraph namespace ngraph
{ {
/** /// A user-defined function.
** A user-defined function.
**/
class Function class Function
{ {
public: public:
Function(const Node::ptr& result, Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters); const std::vector<std::shared_ptr<Parameter>>& parameters);
Node::ptr result() { return m_result; } std::shared_ptr<Node> get_result() { return m_result; }
Parameter::ptr parameter(size_t i) { return m_parameters[i]; } const std::vector<std::shared_ptr<Parameter>> get_parameters() const
std::string name() const { return m_name; } {
return m_parameters;
}
std::string get_name() const { return m_name; }
protected: protected:
Node::ptr m_result; std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
std::string m_name; std::string m_name;
}; };
...@@ -42,10 +44,10 @@ namespace ngraph ...@@ -42,10 +44,10 @@ namespace ngraph
namespace op namespace op
{ {
std::shared_ptr<Function> std::shared_ptr<Function>
function(const Node::ptr& result, function(const std::shared_ptr<Node>& result,
const std::initializer_list<std::shared_ptr<Parameter>>& parameters); const std::initializer_list<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Function> std::shared_ptr<Function>
function(const Node::ptr& result, function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters); const std::vector<std::shared_ptr<Parameter>>& parameters);
} }
} }
...@@ -17,9 +17,10 @@ ...@@ -17,9 +17,10 @@
size_t ngraph::Node::m_next_instance_id = 0; size_t ngraph::Node::m_next_instance_id = 0;
ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type) ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
: TypedValueMixin(type) std::shared_ptr<ValueType> value_type)
, m_arguments(arguments) : m_arguments(arguments)
, m_value_type(value_type)
, m_instance_id(m_next_instance_id++) , m_instance_id(m_next_instance_id++)
{ {
// Add this node as a user of each argument. // Add this node as a user of each argument.
...@@ -47,15 +48,15 @@ namespace ngraph ...@@ -47,15 +48,15 @@ namespace ngraph
auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node); auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node);
if (op_tmp) if (op_tmp)
{ {
out << "Op(" << op_tmp->node_id() << ")"; out << "Op(" << op_tmp->get_node_id() << ")";
} }
else if (parameter_tmp) else if (parameter_tmp)
{ {
out << "Parameter(" << parameter_tmp->node_id() << ")"; out << "Parameter(" << parameter_tmp->get_node_id() << ")";
} }
else else
{ {
out << "Node(" << node.node_id() << ")"; out << "Node(" << node.get_node_id() << ")";
} }
return out; return out;
} }
......
...@@ -20,27 +20,32 @@ ...@@ -20,27 +20,32 @@
#include <iostream> #include <iostream>
#include "type.hpp"
#include "common.hpp" #include "common.hpp"
#include "type.hpp"
namespace ngraph namespace ngraph
{ {
class Op; class Op;
/** /// Nodes are the backbone of the graph of Value dataflow. Every node has
** Nodes are the backbone of the graph of Value dataflow. Every node has /// zero or more nodes as arguments and one value, which is either a tensor
** zero or more nodes as arguments and one value, which is either a tensor /// view or a (possibly empty) tuple of values.
** view or a (possibly empty) tuple of values. class Node : public std::enable_shared_from_this<Node>
**/
class Node : public TypedValueMixin, public std::enable_shared_from_this<Node>
{ {
public:
using ptr = std::shared_ptr<Node>;
protected: protected:
Node(const Nodes& arguments, ValueType::ptr type = nullptr); Node(const Nodes& arguments, std::shared_ptr<ValueType> value_type = nullptr);
Node()
: Node({}, nullptr)
{
}
Node(std::shared_ptr<ValueType> value_type)
: Node({}, value_type)
{
}
virtual ~Node() {} virtual ~Node() {}
public: public:
/// A "one-liner" describing this node. /// A "one-liner" describing this node.
virtual std::string description() const = 0; virtual std::string description() const = 0;
...@@ -48,38 +53,48 @@ namespace ngraph ...@@ -48,38 +53,48 @@ namespace ngraph
/// Propagate types and check arguments for consistency /// Propagate types and check arguments for consistency
virtual void propagate_types() = 0; virtual void propagate_types() = 0;
const Nodes& arguments() const { return m_arguments; } const Nodes& get_arguments() const { return m_arguments; }
const std::multiset<Node*>& users() const { return m_users; } const std::multiset<Node*>& users() const { return m_users; }
std::string name() const { return m_name; } std::string get_name() const { return m_name; }
void name(const std::string& name) { m_name = name; } void set_name(const std::string& name) { m_name = name; }
virtual std::string node_id() const = 0; virtual std::string get_node_id() const = 0;
/** /// Return true if this has the same implementing class as node. This
** Return true if this has the same implementing class as node. This /// will be used by the pattern matcher when comparing a pattern
** will be used by the pattern matcher when comparing a pattern /// graph against the graph.
** graph against the graph. bool is_same_op_type(const std::shared_ptr<Node>& node) const
**/
bool is_same_op_type(const Node::ptr& node) const
{ {
return typeid(*this) == typeid(*node.get()); return typeid(*this) == typeid(*node.get());
} }
std::shared_ptr<ValueType> get_value_type() { return m_value_type; }
const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; }
void set_value_type(const element::Type& element_type, const Shape& shape)
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
void set_value_type(const std::shared_ptr<ValueType>& value_type)
{
m_value_type = value_type;
}
bool is_op() const; bool is_op() const;
bool is_parameter() const; bool is_parameter() const;
size_t instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
protected: protected:
Nodes m_arguments; Nodes m_arguments;
std::shared_ptr<ValueType> m_value_type;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
std::string m_name; std::string m_name;
size_t m_instance_id; size_t m_instance_id;
static size_t m_next_instance_id; static size_t m_next_instance_id;
}; };
using node_ptr = std::shared_ptr<Node>;
} }
...@@ -24,74 +24,87 @@ namespace ngraph ...@@ -24,74 +24,87 @@ namespace ngraph
{ {
namespace op namespace op
{ {
Node::ptr abs(const Node::ptr& arg); std::shared_ptr<Node> abs(const std::shared_ptr<Node>& arg);
Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> add(const std::shared_ptr<Node>& arg0,
Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1); const std::shared_ptr<Node>& arg1);
//Node::ptr convert(); std::shared_ptr<Node> ceiling(const std::shared_ptr<Node>& arg0,
//Node::ptr convolution(); const std::shared_ptr<Node>& arg1);
Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1); //std::shared_ptr<Node> convert();
Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1); //std::shared_ptr<Node> convolution();
Node::ptr exponential(const Node::ptr& arg0); std::shared_ptr<Node> divide(const std::shared_ptr<Node>& arg0,
Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1); const std::shared_ptr<Node>& arg1);
//Node::ptr get_tuple_element(); std::shared_ptr<Node> equal(const std::shared_ptr<Node>& arg0,
Node::ptr greater(const Node::ptr& arg0, const Node::ptr& arg1); const std::shared_ptr<Node>& arg1);
//Node::ptr greater_equal(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> exp(const std::shared_ptr<Node>& arg0);
Node::ptr less(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> floor(const std::shared_ptr<Node>& arg0,
//Node::ptr less_equal(const Node::ptr& arg0, const Node::ptr& arg1); const std::shared_ptr<Node>& arg1);
Node::ptr log(const Node::ptr& arg0); //std::shared_ptr<Node> get_tuple_element();
//Node::ptr logical(); and, or, not std::shared_ptr<Node> greater(const std::shared_ptr<Node>& arg0,
Node::ptr maximum(const Node::ptr& arg0, const Node::ptr& arg1); const std::shared_ptr<Node>& arg1);
Node::ptr minimum(const Node::ptr& arg0, const Node::ptr& arg1); //std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr multiply(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> less(const std::shared_ptr<Node>& arg0,
Node::ptr negate(const Node::ptr& arg0); const std::shared_ptr<Node>& arg1);
//Node::ptr pad(); //std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr power(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> log(const std::shared_ptr<Node>& arg0);
//Node::ptr reduce(); //std::shared_ptr<Node> logical(); and, or, not
// Node::ptr reduce_window(); std::shared_ptr<Node> maximum(const std::shared_ptr<Node>& arg0,
Node::ptr remainder(const Node::ptr& arg0, const Node::ptr& arg1); const std::shared_ptr<Node>& arg1);
Node::ptr reshape(const Node::ptr& arg0, const Shape& shape); std::shared_ptr<Node> minimum(const std::shared_ptr<Node>& arg0,
//Node::ptr reverse(); const std::shared_ptr<Node>& arg1);
//Node::ptr rng(); std::shared_ptr<Node> multiply(const std::shared_ptr<Node>& arg0,
//Node::ptr select(); const std::shared_ptr<Node>& arg1);
//Node::ptr select_scatter(); std::shared_ptr<Node> negative(const std::shared_ptr<Node>& arg0);
//Node::ptr slice(); //std::shared_ptr<Node> pad();
Node::ptr subtract(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> power(const std::shared_ptr<Node>& arg0,
//Node::ptr transpose(); const std::shared_ptr<Node>& arg1);
//Node::ptr while(); //std::shared_ptr<Node> reduce();
// std::shared_ptr<Node> reduce_window();
std::shared_ptr<Node> remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& arg0, const Shape& shape);
//std::shared_ptr<Node> reverse();
//std::shared_ptr<Node> rng();
//std::shared_ptr<Node> select();
//std::shared_ptr<Node> select_scatter();
//std::shared_ptr<Node> slice();
std::shared_ptr<Node> subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> transpose();
//std::shared_ptr<Node> while();
} }
/** /// Op nodes are nodes whose value is the result of some operation
** Op nodes are nodes whose value is the result of some operation /// applied to its arguments. For calls to user functions, the op will
** applied to its arguments. For calls to user functions, the op will /// reference the user function.
** reference the user function.
**/
class Op : public Node class Op : public Node
{ {
public: public:
Op(const std::vector<Node::ptr>& arguments) Op(const std::vector<std::shared_ptr<Node>>& arguments)
: Node(arguments, nullptr) : Node(arguments)
{ {
} }
virtual std::string op_class_name() const = 0; Op()
virtual std::string node_id() const; : Node()
{
}
virtual std::string get_op_class_name() const = 0;
virtual std::string get_node_id() const override;
}; };
/** /// A FunctionOp invokes a function on node arguments. In addition to the argument
** A FunctionOp invokes a function on node arguments. In addition to the argument /// we need to preserve the function.
** we need to preserve the function.
**/
class FunctionOp : public Op class FunctionOp : public Op
{ {
virtual std::string description() const override { return "FunctionOp"; } virtual std::string description() const override { return "FunctionOp"; }
protected: protected:
Node::ptr m_function; std::shared_ptr<Node> m_function;
}; };
/** /// The is an operation we handle directly, i.e. all type checking, etc.
** The is an operation we handle directly, i.e. all type checking, etc. /// are defined in C++ rather than in terms of ngraph operations.
** are defined in C++ rather than in terms of ngraph operations.
**/
class BuiltinOp : public Op class BuiltinOp : public Op
{ {
public: public:
...@@ -100,8 +113,9 @@ namespace ngraph ...@@ -100,8 +113,9 @@ namespace ngraph
// TODO: Implement for each op. This enables graphs to be built for now. // TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {} virtual void propagate_types() override {}
protected: protected:
BuiltinOp(const std::vector<Node::ptr>& args) BuiltinOp(const std::vector<std::shared_ptr<Node>>& args)
: Op(args) : Op(args)
{ {
} }
...@@ -110,204 +124,204 @@ namespace ngraph ...@@ -110,204 +124,204 @@ namespace ngraph
class AbsOp : public BuiltinOp class AbsOp : public BuiltinOp
{ {
public: public:
AbsOp(const Node::ptr& arg0) AbsOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
virtual std::string op_class_name() const override { return "abs"; } virtual std::string get_op_class_name() const override { return "abs"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class AddOp : public BuiltinOp class AddOp : public BuiltinOp
{ {
public: public:
AddOp(const Node::ptr& arg0, const Node::ptr& arg1) AddOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "add"; } virtual std::string get_op_class_name() const override { return "add"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class CeilingOp : public BuiltinOp class CeilingOp : public BuiltinOp
{ {
public: public:
CeilingOp(const Node::ptr& arg0, const Node::ptr& arg1) CeilingOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "ceiling"; } virtual std::string get_op_class_name() const override { return "ceiling"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class DivideOp : public BuiltinOp class DivideOp : public BuiltinOp
{ {
public: public:
DivideOp(const Node::ptr& arg0, const Node::ptr& arg1) DivideOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "divide"; } virtual std::string get_op_class_name() const override { return "divide"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class EqualOp : public BuiltinOp class EqualOp : public BuiltinOp
{ {
public: public:
EqualOp(const Node::ptr& arg0, const Node::ptr& arg1) EqualOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "equal"; } virtual std::string get_op_class_name() const override { return "equal"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class ExponentialOp : public BuiltinOp class ExpOp : public BuiltinOp
{ {
public: public:
ExponentialOp(const Node::ptr& arg0) ExpOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
virtual std::string op_class_name() const override { return "exp"; } virtual std::string get_op_class_name() const override { return "exp"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class FloorOp : public BuiltinOp class FloorOp : public BuiltinOp
{ {
public: public:
FloorOp(const Node::ptr& arg0, const Node::ptr& arg1) FloorOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "floor"; } virtual std::string get_op_class_name() const override { return "floor"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class GreaterOp : public BuiltinOp class GreaterOp : public BuiltinOp
{ {
public: public:
GreaterOp(const Node::ptr& arg0, const Node::ptr& arg1) GreaterOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "greater"; } virtual std::string get_op_class_name() const override { return "greater"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class LessOp : public BuiltinOp class LessOp : public BuiltinOp
{ {
public: public:
LessOp(const Node::ptr& arg0, const Node::ptr& arg1) LessOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "less"; } virtual std::string get_op_class_name() const override { return "less"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class LogOp : public BuiltinOp class LogOp : public BuiltinOp
{ {
public: public:
LogOp(const Node::ptr& arg0) LogOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
virtual std::string op_class_name() const override { return "log"; } virtual std::string get_op_class_name() const override { return "log"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class MaximumOp : public BuiltinOp class MaximumOp : public BuiltinOp
{ {
public: public:
MaximumOp(const Node::ptr& arg0, const Node::ptr& arg1) MaximumOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "max"; } virtual std::string get_op_class_name() const override { return "max"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class MinimumOp : public BuiltinOp class MinimumOp : public BuiltinOp
{ {
public: public:
MinimumOp(const Node::ptr& arg0, const Node::ptr& arg1) MinimumOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "min"; } virtual std::string get_op_class_name() const override { return "min"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class MultiplyOp : public BuiltinOp class MultiplyOp : public BuiltinOp
{ {
public: public:
MultiplyOp(const Node::ptr& arg0, const Node::ptr& arg1) MultiplyOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "multiply"; } virtual std::string get_op_class_name() const override { return "multiply"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class NegateOp : public BuiltinOp class NegativeOp : public BuiltinOp
{ {
public: public:
NegateOp(const Node::ptr& arg0) NegativeOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
virtual std::string op_class_name() const override { return "negate"; } virtual std::string get_op_class_name() const override { return "negative"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class PowerOp : public BuiltinOp class PowerOp : public BuiltinOp
{ {
public: public:
PowerOp(const Node::ptr& arg0, const Node::ptr& arg1) PowerOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "power"; } virtual std::string get_op_class_name() const override { return "power"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class RemainderOp : public BuiltinOp class RemainderOp : public BuiltinOp
{ {
public: public:
RemainderOp(const Node::ptr& arg0, const Node::ptr& arg1) RemainderOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "remainder"; } virtual std::string get_op_class_name() const override { return "remainder"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class ReshapeOp : public BuiltinOp class ReshapeOp : public BuiltinOp
{ {
public: public:
ReshapeOp(const Node::ptr& arg0, const Shape& shape) ReshapeOp(const std::shared_ptr<Node>& arg0, const Shape& shape)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
, m_shape(shape) , m_shape(shape)
{ {
} }
virtual std::string op_class_name() const override { return "reshape"; } virtual std::string get_op_class_name() const override { return "reshape"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
...@@ -316,12 +330,12 @@ namespace ngraph ...@@ -316,12 +330,12 @@ namespace ngraph
class SubtractOp : public BuiltinOp class SubtractOp : public BuiltinOp
{ {
public: public:
SubtractOp(const Node::ptr& arg0, const Node::ptr& arg1) SubtractOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "subtract"; } virtual std::string get_op_class_name() const override { return "subtract"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
} }
...@@ -19,33 +19,33 @@ namespace ngraph ...@@ -19,33 +19,33 @@ namespace ngraph
class BroadcastOp : public BuiltinOp class BroadcastOp : public BuiltinOp
{ {
public: public:
using Axes = std::vector<size_t>; ///
/// @param arg The tensor view to be broadcast.
/** /// @param shape The shape of the result
** /param arg The tensor view to be broadcast. /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** /param shape The shape of the result /// the remaining axes in shape must be the same as the shape of arg.
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast. ///
** the remaining axes in shape must be the same as the shape of arg. BroadcastOp(const std::shared_ptr<Node>& arg,
**/ const Shape& shape,
BroadcastOp(const Node::ptr& arg, const Shape& shape, const Axes& broadcast_axes) const AxisSet& broadcast_axes)
: BuiltinOp({arg}) : BuiltinOp({arg})
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
} }
virtual std::string op_class_name() const override { return "broadcast"; } virtual std::string get_op_class_name() const override { return "broadcast"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
Axes m_broadcast_axes; AxisSet m_broadcast_axes;
}; };
namespace op namespace op
{ {
Node::ptr broadcast(const Node::ptr& tensor, std::shared_ptr<Node> broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape, const Shape& shape,
const BroadcastOp::Axes&& broadcast_axes); AxisSet&& broadcast_axes);
} }
} }
...@@ -18,18 +18,18 @@ namespace ngraph ...@@ -18,18 +18,18 @@ namespace ngraph
{ {
namespace op namespace op
{ {
Node::ptr concatenate(const Nodes& args); std::shared_ptr<Node> concatenate(const Nodes& args);
} }
class ConcatenateOp : public BuiltinOp class ConcatOp : public BuiltinOp
{ {
public: public:
ConcatenateOp(const Nodes& args) ConcatOp(const Nodes& args)
: BuiltinOp(args) : BuiltinOp(args)
{ {
} }
virtual std::string op_class_name() const override { return "concatenate"; } virtual std::string get_op_class_name() const override { return "concatenate"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
...@@ -21,10 +21,10 @@ ...@@ -21,10 +21,10 @@
namespace ngraph namespace ngraph
{ {
// Defines methods to all constant scalars // Defines methods to all constant scalars
class ScalarConstantBaseOp : public Node class ScalarConstantBase : public Node
{ {
protected: protected:
ScalarConstantBaseOp(const std::shared_ptr<TensorViewType>& type) ScalarConstantBase(const std::shared_ptr<TensorViewType>& type)
: Node({}, type) : Node({}, type)
{ {
} }
...@@ -35,47 +35,39 @@ namespace ngraph ...@@ -35,47 +35,39 @@ namespace ngraph
// Implement a constant scalar for each element type. // Implement a constant scalar for each element type.
// The static make method takes a // The static make method takes a
template <typename T> template <typename T>
class ScalarConstantOp : public ScalarConstantBaseOp class ScalarConstant : public ScalarConstantBase
{ {
public: public:
// The ngraph element type // The ngraph element type
using element_type = T; using element_type = T;
// The C++ type that holds the element type // The C++ type that holds the element type
using ctype = typename T::ctype; using type = typename T::type;
ScalarConstantOp(typename T::ctype value) ScalarConstant(typename T::type value)
: ScalarConstantBaseOp(std::make_shared<TensorViewType>(T::type, Shape{})) : ScalarConstantBase(std::make_shared<TensorViewType>(T::element_type(), Shape{}))
, m_value(value) , m_value(value)
{ {
} }
virtual std::string description() const override { return "ConstantScalar"; } virtual std::string description() const override { return "ScalarConstant"; }
virtual std::string node_id() const override virtual std::string get_node_id() const override
{ {
std::stringstream ss; std::stringstream ss;
ss << description() << "_" << node_id(); ss << description() << "_" /* << node_id() */;
return ss.str(); return ss.str();
} }
typename T::ctype value() const { return m_value; } typename T::type get_value() const { return m_value; }
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
template <typename U>
static std::shared_ptr<ScalarConstantOp<T>> make(U value)
{
return std::make_shared<ScalarConstantOp<T>>(value);
}
protected: protected:
typename T::ctype m_value; typename T::type m_value;
}; };
using FloatScalarConstantOp = ScalarConstantOp<element::Float>; using Float32ScalarConstant = ScalarConstant<element::Float32>;
using Int8ScalarConstantOp = ScalarConstantOp<element::Int8>; using Int8ScalarConstant = ScalarConstant<element::Int8>;
using Int32ScalarConstantOp = ScalarConstantOp<element::Int32>; using Int32ScalarConstant = ScalarConstant<element::Int32>;
using Int64ScalarConstantOp = ScalarConstantOp<element::Int64>; using Int64ScalarConstant = ScalarConstant<element::Int64>;
using UInt8ScalarConstantOp = ScalarConstantOp<element::UInt8>; using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstantOp = ScalarConstantOp<element::UInt32>; using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstantOp = ScalarConstantOp<element::UInt64>; using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
} }
...@@ -16,25 +16,25 @@ ...@@ -16,25 +16,25 @@
namespace ngraph namespace ngraph
{ {
class ConvertOp : public BuiltinOp class ConvertOp : public BuiltinOp
{ {
public: public:
ConvertOp(const Node::ptr& arg, const ngraph::element::Type& element_type) ConvertOp(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: BuiltinOp({arg}) : BuiltinOp({arg})
, m_element_type(element_type) , m_element_type(element_type)
{ {
} }
virtual std::string op_class_name() const override { return "convert"; } virtual std::string get_op_class_name() const override { return "convert"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
}; };
namespace op namespace op
{ {
std::shared_ptr<ngraph::ConvertOp> convert(const Node::ptr& arg, const ngraph::element::Type& element_type); std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg,
const ngraph::element::Type& element_type);
} }
} }
...@@ -20,17 +20,18 @@ namespace ngraph ...@@ -20,17 +20,18 @@ namespace ngraph
{ {
public: public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp(const Node::ptr& arg0, const Node::ptr& arg1) DotOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "dot"; } virtual std::string get_op_class_name() const override { return "dot"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
namespace op namespace op
{ {
Node::ptr dot(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
} }
} }
...@@ -21,11 +21,11 @@ namespace ngraph ...@@ -21,11 +21,11 @@ namespace ngraph
{ {
class Function; class Function;
/** ///
** Parameters are nodes that represent the arguments that will be passed to user-defined functions. /// Parameters are nodes that represent the arguments that will be passed to user-defined functions.
** Function creation requires a sequence of parameters. /// Function creation requires a sequence of parameters.
** Basic graph operations do not need parameters attached to a function. /// Basic graph operations do not need parameters attached to a function.
**/ ///
class Parameter : public Node class Parameter : public Node
{ {
friend class Function; friend class Function;
...@@ -36,11 +36,12 @@ namespace ngraph ...@@ -36,11 +36,12 @@ namespace ngraph
void assign_function(Function* function, size_t index); void assign_function(Function* function, size_t index);
public: public:
Parameter(const ValueType::ptr& value_type); Parameter(const std::shared_ptr<ValueType>& value_type);
Parameter(const ngraph::element::Type element_type, const Shape& shape);
std::string description() const override { return "Parameter"; } std::string description() const override { return "Parameter"; }
virtual void propagate_types() override; virtual void propagate_types() override;
virtual std::string node_id() const override; virtual std::string get_node_id() const override;
protected: protected:
Function* m_function; Function* m_function;
...@@ -50,9 +51,10 @@ namespace ngraph ...@@ -50,9 +51,10 @@ namespace ngraph
namespace op namespace op
{ {
/// Factory for frameworks /// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type = nullptr); std::shared_ptr<ngraph::Parameter>
parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
/// Convenience factory for tests /// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const ngraph::element::Type element_type, std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape); const Shape& shape);
} }
} }
...@@ -18,7 +18,7 @@ namespace ngraph ...@@ -18,7 +18,7 @@ namespace ngraph
{ {
namespace op namespace op
{ {
Node::ptr tuple(const Nodes& args); std::shared_ptr<Node> tuple(const Nodes& args);
} }
class TupleOp : public BuiltinOp class TupleOp : public BuiltinOp
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "tuple"; } virtual std::string get_op_class_name() const override { return "tuple"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
...@@ -24,9 +24,7 @@ namespace ngraph ...@@ -24,9 +24,7 @@ namespace ngraph
class Shape class Shape
{ {
public: public:
/** /// @param sizes A sequence of sizes.
** \param sizes A sequence of sizes.
**/
Shape(const std::initializer_list<size_t>& sizes) Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes) : m_sizes(sizes)
{ {
...@@ -37,12 +35,11 @@ namespace ngraph ...@@ -37,12 +35,11 @@ namespace ngraph
{ {
} }
/** /// Conversion to a vector of sizes.
** Conversion to a vector of sizes.
**/
operator const std::vector<size_t>&() const { return m_sizes; } operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; } bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; } bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
protected: protected:
std::vector<size_t> m_sizes; std::vector<size_t> m_sizes;
}; };
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp" #include "topological_sort.hpp"
#include "node.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -21,7 +21,7 @@ using namespace std; ...@@ -21,7 +21,7 @@ using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n) void ngraph::TopologicalSort::promote_node(Node* n)
{ {
for (auto dn=m_dependent_nodes.begin(); dn!=m_dependent_nodes.end(); dn++) for (auto dn = m_dependent_nodes.begin(); dn != m_dependent_nodes.end(); dn++)
{ {
if (dn->first > 0) // Skip zero as they should never be promoted if (dn->first > 0) // Skip zero as they should never be promoted
{ {
...@@ -30,7 +30,7 @@ void ngraph::TopologicalSort::promote_node(Node* n) ...@@ -30,7 +30,7 @@ void ngraph::TopologicalSort::promote_node(Node* n)
{ {
// found the node // found the node
dn->second.erase(it); dn->second.erase(it);
m_dependent_nodes[dn->first-1].push_back(n); m_dependent_nodes[dn->first - 1].push_back(n);
} }
} }
} }
...@@ -38,9 +38,8 @@ void ngraph::TopologicalSort::promote_node(Node* n) ...@@ -38,9 +38,8 @@ void ngraph::TopologicalSort::promote_node(Node* n)
void ngraph::TopologicalSort::process(node_ptr p) void ngraph::TopologicalSort::process(node_ptr p)
{ {
traverse_nodes(p, [&](node_ptr node) traverse_nodes(p, [&](node_ptr node) {
{ list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()];
list<Node*>& node_list = m_dependent_nodes[node->arguments().size()];
node_list.push_back(node.get()); node_list.push_back(node.get());
}); });
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include <memory>
#include <map>
#include <list> #include <list>
#include <map>
#include <memory>
#include <vector>
namespace ngraph namespace ngraph
{ {
......
...@@ -25,123 +25,61 @@ namespace ngraph ...@@ -25,123 +25,61 @@ namespace ngraph
class TensorViewType; class TensorViewType;
class TupleType; class TupleType;
/** /// ValueType is
** ValueType is /// TensorViewType
** TensorViewType /// | TupleType(ValueType[])
** | TupleType(ValueType[])
**/
class ValueType class ValueType
{ {
public: public:
/**
** Preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
virtual ~ValueType() {} virtual ~ValueType() {}
virtual bool operator==(const ValueType::ptr& that) const = 0; virtual bool operator==(const std::shared_ptr<ValueType>& that) const = 0;
bool operator!=(const ValueType::ptr& that) const { return !(*this == that); } bool operator!=(const std::shared_ptr<ValueType>& that) const { return !(*this == that); }
}; };
/** /// Describes a tensor view; an element type and a shape.
** Describes a tensor view; an element type and a shape.
**/
class TensorViewType : public ValueType class TensorViewType : public ValueType
{ {
public: public:
/** /// /param element_type The type of the tensor elements.
** Preferred handle /// /param shape The shape of the tensor.
**/
using ptr = std::shared_ptr<TensorViewType>;
/**
** /param element_type The type of the tensor elements.
** /param shape The shape of the tensor.
**/
TensorViewType(const element::Type& element_type, const Shape& shape) TensorViewType(const element::Type& element_type, const Shape& shape)
: m_element_type(element_type) : m_element_type(element_type)
, m_shape(shape) , m_shape(shape)
{ {
} }
const element::Type& element_type() const { return m_element_type; } const element::Type& get_element_type() const { return m_element_type; }
const Shape& shape() const { return m_shape; } const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType::ptr& that) const override; virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
Shape m_shape; Shape m_shape;
}; };
/** /// Describes a tuple of values; a vector of types
** Describes a tuple of values; a vector of types
**/
class TupleType : public ValueType class TupleType : public ValueType
{ {
public: public:
/** /// Construct empty tuple and add value types later.
** The preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
/**
** Construct empty tuple and add value types later.
**/
TupleType() {} TupleType() {}
/**
** /param element_types A vector of types for the tuple elements
**/
TupleType(const std::vector<ValueType::ptr>& element_types)
: m_element_types(element_types)
{
}
const std::vector<ValueType::ptr> element_types() const { return m_element_types; }
std::vector<ValueType::ptr> element_types() { return m_element_types; }
virtual bool operator==(const ValueType::ptr& that) const override; /// @param element_types A vector of types for the tuple elements
TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types)
protected: : m_element_types(element_types)
std::vector<ValueType::ptr> m_element_types;
};
/**
** Mixin for objects with type information
**/
class TypedValueMixin
{
public:
TypedValueMixin(const ValueType::ptr& type = nullptr)
: m_type(type)
{ {
} }
/** const std::vector<std::shared_ptr<ValueType>> get_element_types() const
** Set the type
** /param type The new type
**/
void type(const ValueType::ptr& type) { m_type = type; }
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void type(const element::Type& element_type, const Shape& shape)
{ {
m_type = std::make_shared<TensorViewType>(element_type, shape); return m_element_types;
} }
std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; }
/** virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
** The type associated with this value.
**/
ValueType::ptr type() { return m_type; }
/**
** The type associated with this value.
**/
const ValueType::ptr type() const { return m_type; }
protected: protected:
ValueType::ptr m_type; std::vector<std::shared_ptr<ValueType>> m_element_types;
}; };
} }
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <list>
#include <fstream>
#include <cstdio> #include <cstdio>
#include <fstream>
#include <list>
#include "visualize.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "util.hpp" #include "util.hpp"
#include "visualize.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -31,18 +31,17 @@ Visualize::Visualize(const string& name) ...@@ -31,18 +31,17 @@ Visualize::Visualize(const string& name)
void Visualize::add(node_ptr p) void Visualize::add(node_ptr p)
{ {
// map<size_t, list<node_ptr>> dependent_nodes; // map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node) traverse_nodes(p, [&](node_ptr node) {
{ for (auto arg : node->get_arguments())
for (auto arg : node->arguments())
{ {
m_ss << " " << arg->node_id() << " -> " << node->node_id() << ";\n"; m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n";
} }
}); });
} }
void Visualize::save_dot(const string& path) const void Visualize::save_dot(const string& path) const
{ {
auto tmp_file = path+".tmp"; auto tmp_file = path + ".tmp";
ofstream out(tmp_file); ofstream out(tmp_file);
if (out) if (out)
{ {
......
...@@ -17,22 +17,20 @@ ...@@ -17,22 +17,20 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
/** /// @param tensor The tensor view to be broadcast.
** /param arg The tensor view to be broadcast. /// @param shape The shape of the result
** /param shape The shape of the result /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast. /// the remaining axes in shape must be the same as the shape of arg.
** the remaining axes in shape must be the same as the shape of arg. std::shared_ptr<Node> ngraph::op::broadcast(const std::shared_ptr<Node>& tensor,
**/
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape, const Shape& shape,
const BroadcastOp::Axes&& broadcast_axes) AxisSet&& broadcast_axes)
{ {
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes); return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
} }
void BroadcastOp::propagate_types() void BroadcastOp::propagate_types()
{ {
auto arg_type = m_arguments.at(0)->type(); auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type) if (nullptr == arg_type)
{ {
throw ngraph_error("Argument to broadcast is missing type."); throw ngraph_error("Argument to broadcast is missing type.");
...@@ -47,11 +45,11 @@ void BroadcastOp::propagate_types() ...@@ -47,11 +45,11 @@ void BroadcastOp::propagate_types()
{ {
target_shape.erase(target_shape.begin() + *i); target_shape.erase(target_shape.begin() + *i);
} }
if (Shape{target_shape} != arg_tensor_view_type->shape()) if (Shape{target_shape} != arg_tensor_view_type->get_shape())
{ {
throw ngraph_error("Broadcast arg, shape, and axes are incompatible"); throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
} }
// TODO If m_type is already set (by framework), this should verify that the type // TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects. // we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape); m_value_type = make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape);
} }
...@@ -19,12 +19,12 @@ ...@@ -19,12 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
void ConcatenateOp::propagate_types() void ConcatOp::propagate_types()
{ {
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
Node::ptr op::concatenate(const std::vector<Node::ptr>& args) std::shared_ptr<Node> op::concatenate(const std::vector<std::shared_ptr<Node>>& args)
{ {
return make_shared<ConcatenateOp>(args); return make_shared<ConcatOp>(args);
} }
...@@ -16,4 +16,4 @@ ...@@ -16,4 +16,4 @@
using namespace ngraph; using namespace ngraph;
void ScalarConstantBaseOp::propagate_types() {} void ScalarConstantBase::propagate_types() {}
...@@ -24,7 +24,8 @@ void ConvertOp::propagate_types() ...@@ -24,7 +24,8 @@ void ConvertOp::propagate_types()
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
shared_ptr<ConvertOp> op::convert(const Node::ptr& arg, const element::Type& element_type) shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg,
const element::Type& element_type)
{ {
return make_shared<ConvertOp>(arg, element_type); return make_shared<ConvertOp>(arg, element_type);
} }
...@@ -20,28 +20,31 @@ using namespace std; ...@@ -20,28 +20,31 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<DotOp>(arg0, arg1); return make_shared<DotOp>(arg0, arg1);
} }
void DotOp::propagate_types() void DotOp::propagate_types()
{ {
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->type()); auto arg0_tensor_type =
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->type()); dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type) if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{ {
throw ngraph_error("Arguments to dot must be tensor views"); throw ngraph_error("Arguments to dot must be tensor views");
} }
if (arg0_tensor_type->element_type() != arg1_tensor_type->element_type()) if (arg0_tensor_type->get_element_type() != arg1_tensor_type->get_element_type())
{ {
throw ngraph_error("Arguments to dot must have the same element type"); throw ngraph_error("Arguments to dot must have the same element type");
} }
// Use NumPy semantics for now // Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis. // Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->shape(); vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->shape(); vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1; size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction; size_t arg1_reduction;
if (arg1_shape.size() > 1) if (arg1_shape.size() > 1)
...@@ -60,5 +63,5 @@ void DotOp::propagate_types() ...@@ -60,5 +63,5 @@ void DotOp::propagate_types()
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end()); copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end()); copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end()); copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape); m_value_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
} }
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
Function::Function(const Node::ptr& result, Function::Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters) const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters)
: m_result(result) : m_result(result)
, m_parameters(parameters) , m_parameters(parameters)
...@@ -30,13 +30,13 @@ Function::Function(const Node::ptr& result ...@@ -30,13 +30,13 @@ Function::Function(const Node::ptr& result
} }
} }
shared_ptr<Function> ngraph::op::function(const Node::ptr& result, shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const initializer_list<shared_ptr<Parameter>>& parameters) const initializer_list<shared_ptr<Parameter>>& parameters)
{ {
return make_shared<Function>(result, parameters); return make_shared<Function>(result, parameters);
} }
shared_ptr<Function> ngraph::op::function(const Node::ptr& result, shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const vector<shared_ptr<Parameter>>& parameters) const vector<shared_ptr<Parameter>>& parameters)
{ {
return make_shared<Function>(result, parameters); return make_shared<Function>(result, parameters);
......
...@@ -20,24 +20,26 @@ ...@@ -20,24 +20,26 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
std::string ngraph::Op::node_id() const std::string ngraph::Op::get_node_id() const
{ {
stringstream ss; stringstream ss;
ss << op_class_name() << "_" << m_instance_id; ss << get_op_class_name() << "_" << m_instance_id;
return ss.str(); return ss.str();
} }
Node::ptr ngraph::op::abs(const Node::ptr& arg) std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
{ {
return make_shared<AbsOp>(arg); return make_shared<AbsOp>(arg);
} }
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<AddOp>(arg0, arg1); return make_shared<AddOp>(arg0, arg1);
} }
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<CeilingOp>(arg0, arg1); return make_shared<CeilingOp>(arg0, arg1);
} }
...@@ -45,61 +47,68 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -45,61 +47,68 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
// 'convert', // 'convert',
// 'convolution', // 'convolution',
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<DivideOp>(arg0, arg1); return make_shared<DivideOp>(arg0, arg1);
} }
Node::ptr ngraph::op::exponential(const Node::ptr& arg0) std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
{ {
return make_shared<ExponentialOp>(arg0); return make_shared<ExpOp>(arg0);
} }
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<FloorOp>(arg0, arg1); return make_shared<FloorOp>(arg0, arg1);
} }
Node::ptr ngraph::op::log(const Node::ptr& arg0) std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
{ {
return make_shared<LogOp>(arg0); return make_shared<LogOp>(arg0);
} }
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MaximumOp>(arg0, arg1); return make_shared<MaximumOp>(arg0, arg1);
} }
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MinimumOp>(arg0, arg1); return make_shared<MinimumOp>(arg0, arg1);
} }
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MultiplyOp>(arg0, arg1); return make_shared<MultiplyOp>(arg0, arg1);
} }
Node::ptr ngraph::op::negate(const Node::ptr& arg0) std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
{ {
return make_shared<NegateOp>(arg0); return make_shared<NegativeOp>(arg0);
} }
// 'pad', // 'pad',
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<PowerOp>(arg0, arg1); return make_shared<PowerOp>(arg0, arg1);
} }
//'reduce', //'reduce',
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<RemainderOp>(arg0, arg1); return make_shared<RemainderOp>(arg0, arg1);
} }
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
{ {
return make_shared<ReshapeOp>(arg0, shape); return make_shared<ReshapeOp>(arg0, shape);
} }
...@@ -109,7 +118,8 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) ...@@ -109,7 +118,8 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
// 'select', // 'select',
//'slice', //'slice',
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<SubtractOp>(arg0, arg1); return make_shared<SubtractOp>(arg0, arg1);
} }
......
...@@ -19,13 +19,18 @@ ...@@ -19,13 +19,18 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
Parameter::Parameter(const ValueType::ptr& value_type) Parameter::Parameter(const std::shared_ptr<ValueType>& value_type)
: Node({}, value_type) : Node(value_type)
, m_function(nullptr) , m_function(nullptr)
, m_index(0) , m_index(0)
{ {
} }
Parameter::Parameter(const ngraph::element::Type element_type, const Shape& shape)
: Parameter(make_shared<TensorViewType>(element_type, shape))
{
}
void Parameter::assign_function(Function* function, size_t index) void Parameter::assign_function(Function* function, size_t index)
{ {
if (nullptr != m_function) if (nullptr != m_function)
...@@ -36,11 +41,9 @@ void Parameter::assign_function(Function* function, size_t index) ...@@ -36,11 +41,9 @@ void Parameter::assign_function(Function* function, size_t index)
m_index = index; m_index = index;
} }
void Parameter::propagate_types() void Parameter::propagate_types() {}
{
}
shared_ptr<Parameter> ngraph::op::parameter(const ValueType::ptr& value_type) shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type)
{ {
return make_shared<Parameter>(value_type); return make_shared<Parameter>(value_type);
} }
...@@ -51,7 +54,7 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_ ...@@ -51,7 +54,7 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape)); return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
} }
std::string ngraph::Parameter::node_id() const std::string ngraph::Parameter::get_node_id() const
{ {
stringstream ss; stringstream ss;
ss << "parameter_" << m_instance_id; ss << "parameter_" << m_instance_id;
......
...@@ -24,7 +24,7 @@ void TupleOp::propagate_types() ...@@ -24,7 +24,7 @@ void TupleOp::propagate_types()
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
Node::ptr op::tuple(const std::vector<Node::ptr>& args) std::shared_ptr<Node> op::tuple(const std::vector<std::shared_ptr<Node>>& args)
{ {
return make_shared<TupleOp>(args); return make_shared<TupleOp>(args);
} }
...@@ -48,28 +48,3 @@ size_t ngraph::element::Type::size() const ...@@ -48,28 +48,3 @@ size_t ngraph::element::Type::size() const
{ {
return std::ceil((float)m_bitwidth / 8.0); return std::ceil((float)m_bitwidth / 8.0);
} }
namespace
{
const element::Float s_float32_t = element::Float{"float"};
const element::Int8 s_int8_t = element::Int8{"int8_t"};
const element::Int32 s_int32_t = element::Int32{"int32_t"};
const element::Int64 s_int64_t = element::Int64{"int64_t"};
const element::UInt8 s_uint8_t = element::UInt8{"uint8_t"};
const element::UInt32 s_uint32_t = element::UInt32{"uint32_t"};
const element::UInt64 s_uint64_t = element::UInt64{"uint64_t"};
}
template <>
const element::TraitedType<float>& element::TraitedType<float>::type = s_float32_t;
template <>
const element::TraitedType<int8_t>& element::TraitedType<int8_t>::type = s_int8_t;
template <>
const element::TraitedType<int32_t>& element::TraitedType<int32_t>::type = s_int32_t;
template <>
const element::TraitedType<int64_t>& element::TraitedType<int64_t>::type = s_int64_t;
template <>
const element::TraitedType<uint8_t>& element::TraitedType<uint8_t>::type = s_uint8_t;
template <>
const element::TraitedType<uint32_t>& element::TraitedType<uint32_t>::type = s_uint32_t;
template <>
const element::TraitedType<uint64_t>& element::TraitedType<uint64_t>::type = s_uint64_t;
\ No newline at end of file
...@@ -19,30 +19,30 @@ ...@@ -19,30 +19,30 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool TensorViewType::operator==(const ValueType::ptr& that) const bool TensorViewType::operator==(const std::shared_ptr<ValueType>& that) const
{ {
auto that_tvt = dynamic_pointer_cast<TensorViewType>(that); auto that_tvt = dynamic_pointer_cast<TensorViewType>(that);
if (nullptr == that_tvt) if (nullptr == that_tvt)
{ {
return false; return false;
} }
if (that_tvt->element_type() != m_element_type) if (that_tvt->get_element_type() != m_element_type)
{ {
return false; return false;
} }
if (that_tvt->shape() != m_shape) if (that_tvt->get_shape() != m_shape)
{ {
return false; return false;
} }
return true; return true;
} }
bool TupleType::operator==(const ValueType::ptr& that) const bool TupleType::operator==(const std::shared_ptr<ValueType>& that) const
{ {
auto that_tvt = dynamic_pointer_cast<TupleType>(that); auto that_tvt = dynamic_pointer_cast<TupleType>(that);
if (nullptr == that_tvt) if (nullptr == that_tvt)
{ {
return false; return false;
} }
return that_tvt->element_types() == element_types(); return that_tvt->get_element_types() == get_element_types();
} }
...@@ -26,8 +26,8 @@ void ngraph::dump(ostream& out, const void* _data, size_t _size) ...@@ -26,8 +26,8 @@ void ngraph::dump(ostream& out, const void* _data, size_t _size)
{ {
auto flags = out.flags(); auto flags = out.flags();
const uint8_t* data = reinterpret_cast<const uint8_t*>(_data); const uint8_t* data = reinterpret_cast<const uint8_t*>(_data);
int len = _size; size_t len = _size;
int index = 0; size_t index = 0;
while (index < len) while (index < len)
{ {
out << std::hex << std::setw(8) << std::setfill('0') << index; out << std::hex << std::setw(8) << std::setfill('0') << index;
...@@ -136,11 +136,11 @@ static void traverse_nodes(std::shared_ptr<ngraph::Node> p, ...@@ -136,11 +136,11 @@ static void traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::set<size_t>& instances_seen) std::set<size_t>& instances_seen)
{ {
f(p); f(p);
for (auto arg : p->arguments()) for (auto arg : p->get_arguments())
{ {
if (instances_seen.find(arg->instance_id()) == instances_seen.end()) if (instances_seen.find(arg->get_instance_id()) == instances_seen.end())
{ {
instances_seen.insert(arg->instance_id()); instances_seen.insert(arg->get_instance_id());
traverse_nodes(arg, f, instances_seen); traverse_nodes(arg, f, instances_seen);
} }
} }
......
...@@ -16,39 +16,40 @@ ...@@ -16,39 +16,40 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include <memory>
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto arg0 = op::parameter(element::Float::type, {7, 3}); auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = op::parameter(element::Float::type, {3}); auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = op::parameter(element::Float::type, {32, 7}); auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = op::parameter(element::Float::type, {32, 7}); auto arg3 = node<Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0}); auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = op::dot(arg2, arg0); auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
ASSERT_EQ(2, dot->arguments().size()); auto dot = node<DotOp>(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2); ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0); ASSERT_EQ(dot->get_arguments()[1], arg0);
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3}); auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->result(), dot); ASSERT_EQ(cluster_0->get_result(), dot);
} }
// Check upcasting from ValueType. // Check upcasting from ValueType.
TEST(build_graph, as_type) TEST(build_graph, as_type)
{ {
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple. // Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::Float::type, Shape{2, 3, 5}); auto tv_vt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt); auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv); ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt); auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
ASSERT_EQ(nullptr, tv_tp); ASSERT_EQ(nullptr, tv_tp);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple. // Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
ValueType::ptr tp_vt = make_shared<TupleType>(vector<ValueType::ptr>{tv_vt, tv_vt}); auto tp_vt = make_shared<TupleType>(ValueTypes{tv_vt, tv_vt});
auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt); auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt);
ASSERT_EQ(nullptr, tp_tv); ASSERT_EQ(nullptr, tp_tv);
auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt); auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt);
...@@ -58,15 +59,15 @@ TEST(build_graph, as_type) ...@@ -58,15 +59,15 @@ TEST(build_graph, as_type)
// Check node comparisons // Check node comparisons
TEST(build_graph, node_comparison) TEST(build_graph, node_comparison)
{ {
auto arg0 = op::parameter(element::Float::type, {32, 3}); auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{32, 3});
auto arg1 = op::parameter(element::Float::type, {3}); auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = op::parameter(element::Float::type, {32}); auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32});
auto dot = op::dot(arg0, arg1); auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2); auto add = op::add(dot, arg2);
auto parg = op::parameter(element::Float::type, {}); auto parg = node<Parameter>(element::Float32::element_type(), Shape{});
auto pattern_dot = op::dot(parg, parg); auto pattern_dot = node<DotOp>(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot)); ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented. // TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong. // Need to figure out what's wrong.
...@@ -76,27 +77,26 @@ TEST(build_graph, node_comparison) ...@@ -76,27 +77,26 @@ TEST(build_graph, node_comparison)
TEST(build_graph, literal) TEST(build_graph, literal)
{ {
// float scalar from a float // float scalar from a float
auto float0 = FloatScalarConstantOp::make(3.0); //auto float0 = FloatScalarConstant::make(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::type, Shape{}); auto float0 = node<Float32ScalarConstant>(3.0);
ASSERT_EQ(float0->value(), 3.0); auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
ASSERT_EQ(*float0->type(), float_scalar_type); ASSERT_EQ(float0->get_value(), 3.0);
auto d = op::dot(float0, float0); ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
ASSERT_EQ(d->arguments().at(0), float0); auto d = node<DotOp>(float0, float0);
ASSERT_EQ(d->arguments().at(1), float0); ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0);
// float scalar from an int // float scalar from an int
auto float1 = FloatScalarConstantOp::make(3); auto float1 = node<Float32ScalarConstant>(3);
ASSERT_EQ(float1->value(), 3); ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->type(), float_scalar_type); ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
auto int32_0 = Int32ScalarConstantOp::make(3.0); auto int32_0 = node<Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::type, Shape{}); auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->value(), 3); ASSERT_EQ(int32_0->get_value(), 3);
ASSERT_EQ(*int32_0->type(), int32_scalar_type); ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type);
ASSERT_NE(*int32_0->type(), float_scalar_type); ASSERT_NE(*int32_0->get_value_type(), float_scalar_type);
} }
// Check argument inverses // Check argument inverses
TEST(build_graph, arg_inverse) TEST(build_graph, arg_inverse) {}
{
}
...@@ -23,7 +23,7 @@ using namespace ngraph; ...@@ -23,7 +23,7 @@ using namespace ngraph;
TEST(op, is_op) TEST(op, is_op)
{ {
auto arg0 = op::parameter(element::Float::type, {1}); auto arg0 = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg0); ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter()); EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op()); EXPECT_FALSE(arg0->is_op());
...@@ -31,7 +31,7 @@ TEST(op, is_op) ...@@ -31,7 +31,7 @@ TEST(op, is_op)
TEST(op, is_parameter) TEST(op, is_parameter)
{ {
auto arg0 = op::parameter(element::Float::type, {1}); auto arg0 = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg0); ASSERT_NE(nullptr, arg0);
auto t0 = op::add(arg0, arg0); auto t0 = op::add(arg0, arg0);
ASSERT_NE(nullptr, t0); ASSERT_NE(nullptr, t0);
......
...@@ -29,18 +29,17 @@ using namespace ngraph; ...@@ -29,18 +29,17 @@ using namespace ngraph;
static bool validate_list(const vector<Node*>& nodes) static bool validate_list(const vector<Node*>& nodes)
{ {
bool rc = true; bool rc = true;
for (auto it=nodes.rbegin(); it!=nodes.rend(); it++) for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{ {
Node* node = *it;
auto node_tmp = *it; auto node_tmp = *it;
auto dependencies_tmp = node_tmp->arguments(); auto dependencies_tmp = node_tmp->get_arguments();
vector<Node*> dependencies; vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp) for (shared_ptr<Node> n : dependencies_tmp)
{ {
dependencies.push_back(n.get()); dependencies.push_back(n.get());
} }
auto tmp = it+1; auto tmp = it + 1;
for (; tmp!=nodes.rend(); tmp++) for (; tmp != nodes.rend(); tmp++)
{ {
auto dep_tmp = *tmp; auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp); auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
...@@ -60,9 +59,9 @@ static bool validate_list(const vector<Node*>& nodes) ...@@ -60,9 +59,9 @@ static bool validate_list(const vector<Node*>& nodes)
TEST(topological_sort, basic) TEST(topological_sort, basic)
{ {
vector<shared_ptr<Parameter>> args; vector<shared_ptr<Parameter>> args;
for (int i=0; i<10; i++) for (int i = 0; i < 10; i++)
{ {
auto arg = op::parameter(element::Float::type, {1}); auto arg = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg); ASSERT_NE(nullptr, arg);
args.push_back(arg); args.push_back(arg);
} }
...@@ -79,13 +78,13 @@ TEST(topological_sort, basic) ...@@ -79,13 +78,13 @@ TEST(topological_sort, basic)
auto t4 = op::add(t2, args[5]); auto t4 = op::add(t2, args[5]);
ASSERT_NE(nullptr, t3); ASSERT_NE(nullptr, t3);
Node::ptr r0 = op::add(t3, t4); auto r0 = op::add(t3, t4);
ASSERT_NE(nullptr, r0); ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, args); auto f0 = op::function(r0, args);
ASSERT_NE(nullptr, f0); ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->arguments().size()); ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0); auto op_r0 = static_pointer_cast<Op>(r0);
Visualize vz; Visualize vz;
......
...@@ -134,9 +134,7 @@ TEST(util, contains) ...@@ -134,9 +134,7 @@ TEST(util, contains)
EXPECT_FALSE(contains(v1, 8)); EXPECT_FALSE(contains(v1, 8));
} }
TEST(util, remove_from) TEST(util, remove_from) {}
{
}
TEST(util, reduce) TEST(util, reduce)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment