Commit 973b3a0e authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #74 from NervanaSystems/cyphers/names

Cyphers/names
parents fac27c37 fd881acc
......@@ -39,6 +39,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-potentially-evaluated-expression") # Triggers false alarms on typeid
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-weak-vtables") # Not ready for this yet
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-conversion")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-equal")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-duplicate-enum") # from numpy
......
......@@ -15,20 +15,33 @@
#pragma once
#include <memory>
#include <vector>
#include <set>
#include <vector>
// Names for types that aren't worth giving their own classes
namespace ngraph
{
class Node;
class Parameter;
class ValueType;
template <typename T, typename... A>
std::shared_ptr<T> node(A&&... args)
{
return std::make_shared<T>(args...);
}
/// Zero or more value types
using ValueTypes = std::vector<std::shared_ptr<ValueType>>;
/// Zero or more nodes
using Nodes = std::vector<std::shared_ptr<Node>>;
/// A set of indices, for example, reduction axes
using IndexSet = std::set<size_t>;
/// A sequence of axes
using AxisVector = std::vector<size_t>;
/// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>;
/// A list of parameters
using Parameters = std::vector<std::shared_ptr<Parameter>>;
......
......@@ -22,6 +22,8 @@
#include <string>
#include <type_traits>
#include "except.hpp"
namespace ngraph
{
namespace element
......@@ -41,42 +43,77 @@ namespace ngraph
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
private:
static std::map<std::string, Type> m_element_list;
size_t m_bitwidth;
bool m_is_float;
bool m_is_signed;
const std::string m_cname;
size_t m_bitwidth;
bool m_is_float;
bool m_is_signed;
const std::string& m_cname;
};
// Provides a compile-time name for a C++ type.
// Used in TraitedType for the string that supplies the C++ type name during code generation,
// so it needs to be a valid C++ name.
template <typename T>
const char* traited_type_name()
{
throw ngraph_error("Unknown type");
}
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TRAITED_TYPE_NAME(T) \
template <> \
constexpr const char* traited_type_name<T>() \
{ \
return #T; \
}
// Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation.
template <typename T>
class TraitedType : public Type
{
public:
// This is the C++ type used to hold a value of this element type during compilation
using ctype = T;
// This is a reference to an instance of this element type.
static const TraitedType<T>& type;
TraitedType(const std::string& cname)
protected:
TraitedType()
: Type(sizeof(T) * 8,
std::is_floating_point<T>::value,
std::is_signed<T>::value,
cname)
traited_type_name<T>())
{
}
public:
// This is the C++ type used to hold a value of this element type during compilation
using type = T;
// This returns a reference to an instance of this element type.
static const TraitedType<T>& element_type()
{
static TraitedType<T> t;
return t;
}
};
// Human-readable names for the element types
using Float = TraitedType<float>;
using Int8 = TraitedType<int8_t>;
using Int32 = TraitedType<int32_t>;
using Int64 = TraitedType<int64_t>;
using UInt8 = TraitedType<uint8_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(float)
using Float32 = TraitedType<float>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int8_t)
using Int8 = TraitedType<int8_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int32_t)
using Int32 = TraitedType<int32_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int64_t)
using Int64 = TraitedType<int64_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint8_t)
using UInt8 = TraitedType<uint8_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint32_t)
using UInt32 = TraitedType<uint32_t>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint64_t)
using UInt64 = TraitedType<uint64_t>;
}
}
......@@ -21,20 +21,22 @@
namespace ngraph
{
/**
** A user-defined function.
**/
/// A user-defined function.
class Function
{
public:
Function(const Node::ptr& result,
Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
Node::ptr result() { return m_result; }
Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
std::string name() const { return m_name; }
std::shared_ptr<Node> get_result() { return m_result; }
const std::vector<std::shared_ptr<Parameter>> get_parameters() const
{
return m_parameters;
}
std::string get_name() const { return m_name; }
protected:
Node::ptr m_result;
std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
std::string m_name;
};
......@@ -42,10 +44,10 @@ namespace ngraph
namespace op
{
std::shared_ptr<Function>
function(const Node::ptr& result,
function(const std::shared_ptr<Node>& result,
const std::initializer_list<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Function>
function(const Node::ptr& result,
function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
}
}
......@@ -17,9 +17,10 @@
size_t ngraph::Node::m_next_instance_id = 0;
ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type)
: TypedValueMixin(type)
, m_arguments(arguments)
ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
std::shared_ptr<ValueType> value_type)
: m_arguments(arguments)
, m_value_type(value_type)
, m_instance_id(m_next_instance_id++)
{
// Add this node as a user of each argument.
......@@ -47,15 +48,15 @@ namespace ngraph
auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node);
if (op_tmp)
{
out << "Op(" << op_tmp->node_id() << ")";
out << "Op(" << op_tmp->get_node_id() << ")";
}
else if (parameter_tmp)
{
out << "Parameter(" << parameter_tmp->node_id() << ")";
out << "Parameter(" << parameter_tmp->get_node_id() << ")";
}
else
{
out << "Node(" << node.node_id() << ")";
out << "Node(" << node.get_node_id() << ")";
}
return out;
}
......
......@@ -20,27 +20,32 @@
#include <iostream>
#include "type.hpp"
#include "common.hpp"
#include "type.hpp"
namespace ngraph
{
class Op;
/**
** Nodes are the backbone of the graph of Value dataflow. Every node has
** zero or more nodes as arguments and one value, which is either a tensor
** view or a (possibly empty) tuple of values.
**/
class Node : public TypedValueMixin, public std::enable_shared_from_this<Node>
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values.
class Node : public std::enable_shared_from_this<Node>
{
public:
using ptr = std::shared_ptr<Node>;
protected:
Node(const Nodes& arguments, ValueType::ptr type = nullptr);
Node(const Nodes& arguments, std::shared_ptr<ValueType> value_type = nullptr);
Node()
: Node({}, nullptr)
{
}
Node(std::shared_ptr<ValueType> value_type)
: Node({}, value_type)
{
}
virtual ~Node() {}
public:
/// A "one-liner" describing this node.
virtual std::string description() const = 0;
......@@ -48,38 +53,48 @@ namespace ngraph
/// Propagate types and check arguments for consistency
virtual void propagate_types() = 0;
const Nodes& arguments() const { return m_arguments; }
const Nodes& get_arguments() const { return m_arguments; }
const std::multiset<Node*>& users() const { return m_users; }
std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; }
std::string get_name() const { return m_name; }
void set_name(const std::string& name) { m_name = name; }
virtual std::string node_id() const = 0;
virtual std::string get_node_id() const = 0;
/**
** Return true if this has the same implementing class as node. This
** will be used by the pattern matcher when comparing a pattern
** graph against the graph.
**/
bool is_same_op_type(const Node::ptr& node) const
/// Return true if this has the same implementing class as node. This
/// will be used by the pattern matcher when comparing a pattern
/// graph against the graph.
bool is_same_op_type(const std::shared_ptr<Node>& node) const
{
return typeid(*this) == typeid(*node.get());
}
std::shared_ptr<ValueType> get_value_type() { return m_value_type; }
const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; }
void set_value_type(const element::Type& element_type, const Shape& shape)
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
void set_value_type(const std::shared_ptr<ValueType>& value_type)
{
m_value_type = value_type;
}
bool is_op() const;
bool is_parameter() const;
size_t instance_id() const { return m_instance_id; }
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
protected:
Nodes m_arguments;
std::multiset<Node*> m_users;
std::string m_name;
size_t m_instance_id;
static size_t m_next_instance_id;
Nodes m_arguments;
std::shared_ptr<ValueType> m_value_type;
std::multiset<Node*> m_users;
std::string m_name;
size_t m_instance_id;
static size_t m_next_instance_id;
};
using node_ptr = std::shared_ptr<Node>;
}
......@@ -24,74 +24,87 @@ namespace ngraph
{
namespace op
{
Node::ptr abs(const Node::ptr& arg);
Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr convert();
//Node::ptr convolution();
Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr exponential(const Node::ptr& arg0);
Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr get_tuple_element();
Node::ptr greater(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr greater_equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr less(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr less_equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr log(const Node::ptr& arg0);
//Node::ptr logical(); and, or, not
Node::ptr maximum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr minimum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr multiply(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr negate(const Node::ptr& arg0);
//Node::ptr pad();
Node::ptr power(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr reduce();
// Node::ptr reduce_window();
Node::ptr remainder(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr reshape(const Node::ptr& arg0, const Shape& shape);
//Node::ptr reverse();
//Node::ptr rng();
//Node::ptr select();
//Node::ptr select_scatter();
//Node::ptr slice();
Node::ptr subtract(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr transpose();
//Node::ptr while();
std::shared_ptr<Node> abs(const std::shared_ptr<Node>& arg);
std::shared_ptr<Node> add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> convert();
//std::shared_ptr<Node> convolution();
std::shared_ptr<Node> divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> equal(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> exp(const std::shared_ptr<Node>& arg0);
std::shared_ptr<Node> floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> get_tuple_element();
std::shared_ptr<Node> greater(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> less(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> log(const std::shared_ptr<Node>& arg0);
//std::shared_ptr<Node> logical(); and, or, not
std::shared_ptr<Node> maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> negative(const std::shared_ptr<Node>& arg0);
//std::shared_ptr<Node> pad();
std::shared_ptr<Node> power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> reduce();
// std::shared_ptr<Node> reduce_window();
std::shared_ptr<Node> remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& arg0, const Shape& shape);
//std::shared_ptr<Node> reverse();
//std::shared_ptr<Node> rng();
//std::shared_ptr<Node> select();
//std::shared_ptr<Node> select_scatter();
//std::shared_ptr<Node> slice();
std::shared_ptr<Node> subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> transpose();
//std::shared_ptr<Node> while();
}
/**
** Op nodes are nodes whose value is the result of some operation
** applied to its arguments. For calls to user functions, the op will
** reference the user function.
**/
/// Op nodes are nodes whose value is the result of some operation
/// applied to its arguments. For calls to user functions, the op will
/// reference the user function.
class Op : public Node
{
public:
Op(const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr)
Op(const std::vector<std::shared_ptr<Node>>& arguments)
: Node(arguments)
{
}
virtual std::string op_class_name() const = 0;
virtual std::string node_id() const;
Op()
: Node()
{
}
virtual std::string get_op_class_name() const = 0;
virtual std::string get_node_id() const override;
};
/**
** A FunctionOp invokes a function on node arguments. In addition to the argument
** we need to preserve the function.
**/
/// A FunctionOp invokes a function on node arguments. In addition to the argument
/// we need to preserve the function.
class FunctionOp : public Op
{
virtual std::string description() const override { return "FunctionOp"; }
protected:
Node::ptr m_function;
std::shared_ptr<Node> m_function;
};
/**
** The is an operation we handle directly, i.e. all type checking, etc.
** are defined in C++ rather than in terms of ngraph operations.
**/
/// The is an operation we handle directly, i.e. all type checking, etc.
/// are defined in C++ rather than in terms of ngraph operations.
class BuiltinOp : public Op
{
public:
......@@ -100,8 +113,9 @@ namespace ngraph
// TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {}
protected:
BuiltinOp(const std::vector<Node::ptr>& args)
BuiltinOp(const std::vector<std::shared_ptr<Node>>& args)
: Op(args)
{
}
......@@ -110,204 +124,204 @@ namespace ngraph
class AbsOp : public BuiltinOp
{
public:
AbsOp(const Node::ptr& arg0)
AbsOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0})
{
}
virtual std::string op_class_name() const override { return "abs"; }
virtual std::string get_op_class_name() const override { return "abs"; }
//virtual void propagate_types() override;
};
class AddOp : public BuiltinOp
{
public:
AddOp(const Node::ptr& arg0, const Node::ptr& arg1)
AddOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "add"; }
virtual std::string get_op_class_name() const override { return "add"; }
//virtual void propagate_types() override;
};
class CeilingOp : public BuiltinOp
{
public:
CeilingOp(const Node::ptr& arg0, const Node::ptr& arg1)
CeilingOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "ceiling"; }
virtual std::string get_op_class_name() const override { return "ceiling"; }
//virtual void propagate_types() override;
};
class DivideOp : public BuiltinOp
{
public:
DivideOp(const Node::ptr& arg0, const Node::ptr& arg1)
DivideOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "divide"; }
virtual std::string get_op_class_name() const override { return "divide"; }
//virtual void propagate_types() override;
};
class EqualOp : public BuiltinOp
{
public:
EqualOp(const Node::ptr& arg0, const Node::ptr& arg1)
EqualOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "equal"; }
virtual std::string get_op_class_name() const override { return "equal"; }
//virtual void propagate_types() override;
};
class ExponentialOp : public BuiltinOp
class ExpOp : public BuiltinOp
{
public:
ExponentialOp(const Node::ptr& arg0)
ExpOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0})
{
}
virtual std::string op_class_name() const override { return "exp"; }
virtual std::string get_op_class_name() const override { return "exp"; }
//virtual void propagate_types() override;
};
class FloorOp : public BuiltinOp
{
public:
FloorOp(const Node::ptr& arg0, const Node::ptr& arg1)
FloorOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "floor"; }
virtual std::string get_op_class_name() const override { return "floor"; }
//virtual void propagate_types() override;
};
class GreaterOp : public BuiltinOp
{
public:
GreaterOp(const Node::ptr& arg0, const Node::ptr& arg1)
GreaterOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "greater"; }
virtual std::string get_op_class_name() const override { return "greater"; }
//virtual void propagate_types() override;
};
class LessOp : public BuiltinOp
{
public:
LessOp(const Node::ptr& arg0, const Node::ptr& arg1)
LessOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "less"; }
virtual std::string get_op_class_name() const override { return "less"; }
//virtual void propagate_types() override;
};
class LogOp : public BuiltinOp
{
public:
LogOp(const Node::ptr& arg0)
LogOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0})
{
}
virtual std::string op_class_name() const override { return "log"; }
virtual std::string get_op_class_name() const override { return "log"; }
//virtual void propagate_types() override;
};
class MaximumOp : public BuiltinOp
{
public:
MaximumOp(const Node::ptr& arg0, const Node::ptr& arg1)
MaximumOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "max"; }
virtual std::string get_op_class_name() const override { return "max"; }
//virtual void propagate_types() override;
};
class MinimumOp : public BuiltinOp
{
public:
MinimumOp(const Node::ptr& arg0, const Node::ptr& arg1)
MinimumOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "min"; }
virtual std::string get_op_class_name() const override { return "min"; }
//virtual void propagate_types() override;
};
class MultiplyOp : public BuiltinOp
{
public:
MultiplyOp(const Node::ptr& arg0, const Node::ptr& arg1)
MultiplyOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "multiply"; }
virtual std::string get_op_class_name() const override { return "multiply"; }
//virtual void propagate_types() override;
};
class NegateOp : public BuiltinOp
class NegativeOp : public BuiltinOp
{
public:
NegateOp(const Node::ptr& arg0)
NegativeOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0})
{
}
virtual std::string op_class_name() const override { return "negate"; }
virtual std::string get_op_class_name() const override { return "negative"; }
//virtual void propagate_types() override;
};
class PowerOp : public BuiltinOp
{
public:
PowerOp(const Node::ptr& arg0, const Node::ptr& arg1)
PowerOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "power"; }
virtual std::string get_op_class_name() const override { return "power"; }
//virtual void propagate_types() override;
};
class RemainderOp : public BuiltinOp
{
public:
RemainderOp(const Node::ptr& arg0, const Node::ptr& arg1)
RemainderOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "remainder"; }
virtual std::string get_op_class_name() const override { return "remainder"; }
//virtual void propagate_types() override;
};
class ReshapeOp : public BuiltinOp
{
public:
ReshapeOp(const Node::ptr& arg0, const Shape& shape)
ReshapeOp(const std::shared_ptr<Node>& arg0, const Shape& shape)
: BuiltinOp({arg0})
, m_shape(shape)
{
}
virtual std::string op_class_name() const override { return "reshape"; }
virtual std::string get_op_class_name() const override { return "reshape"; }
//virtual void propagate_types() override;
protected:
Shape m_shape;
......@@ -316,12 +330,12 @@ namespace ngraph
class SubtractOp : public BuiltinOp
{
public:
SubtractOp(const Node::ptr& arg0, const Node::ptr& arg1)
SubtractOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "subtract"; }
virtual std::string get_op_class_name() const override { return "subtract"; }
//virtual void propagate_types() override;
};
}
......@@ -19,33 +19,33 @@ namespace ngraph
class BroadcastOp : public BuiltinOp
{
public:
using Axes = std::vector<size_t>;
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastOp(const Node::ptr& arg, const Shape& shape, const Axes& broadcast_axes)
///
/// @param arg The tensor view to be broadcast.
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
///
BroadcastOp(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
virtual std::string op_class_name() const override { return "broadcast"; }
virtual std::string get_op_class_name() const override { return "broadcast"; }
virtual void propagate_types() override;
protected:
Shape m_shape;
Axes m_broadcast_axes;
Shape m_shape;
AxisSet m_broadcast_axes;
};
namespace op
{
Node::ptr broadcast(const Node::ptr& tensor,
const Shape& shape,
const BroadcastOp::Axes&& broadcast_axes);
std::shared_ptr<Node> broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape,
AxisSet&& broadcast_axes);
}
}
......@@ -18,18 +18,18 @@ namespace ngraph
{
namespace op
{
Node::ptr concatenate(const Nodes& args);
std::shared_ptr<Node> concatenate(const Nodes& args);
}
class ConcatenateOp : public BuiltinOp
class ConcatOp : public BuiltinOp
{
public:
ConcatenateOp(const Nodes& args)
ConcatOp(const Nodes& args)
: BuiltinOp(args)
{
}
virtual std::string op_class_name() const override { return "concatenate"; }
virtual std::string get_op_class_name() const override { return "concatenate"; }
virtual void propagate_types() override;
};
}
......@@ -21,10 +21,10 @@
namespace ngraph
{
// Defines methods to all constant scalars
class ScalarConstantBaseOp : public Node
class ScalarConstantBase : public Node
{
protected:
ScalarConstantBaseOp(const std::shared_ptr<TensorViewType>& type)
ScalarConstantBase(const std::shared_ptr<TensorViewType>& type)
: Node({}, type)
{
}
......@@ -35,47 +35,39 @@ namespace ngraph
// Implement a constant scalar for each element type.
// The static make method takes a
template <typename T>
class ScalarConstantOp : public ScalarConstantBaseOp
class ScalarConstant : public ScalarConstantBase
{
public:
// The ngraph element type
using element_type = T;
// The C++ type that holds the element type
using ctype = typename T::ctype;
using type = typename T::type;
ScalarConstantOp(typename T::ctype value)
: ScalarConstantBaseOp(std::make_shared<TensorViewType>(T::type, Shape{}))
ScalarConstant(typename T::type value)
: ScalarConstantBase(std::make_shared<TensorViewType>(T::element_type(), Shape{}))
, m_value(value)
{
}
virtual std::string description() const override { return "ConstantScalar"; }
virtual std::string node_id() const override
virtual std::string description() const override { return "ScalarConstant"; }
virtual std::string get_node_id() const override
{
std::stringstream ss;
ss << description() << "_" << node_id();
ss << description() << "_" /* << node_id() */;
return ss.str();
}
typename T::ctype value() const { return m_value; }
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
template <typename U>
static std::shared_ptr<ScalarConstantOp<T>> make(U value)
{
return std::make_shared<ScalarConstantOp<T>>(value);
}
typename T::type get_value() const { return m_value; }
protected:
typename T::ctype m_value;
typename T::type m_value;
};
using FloatScalarConstantOp = ScalarConstantOp<element::Float>;
using Int8ScalarConstantOp = ScalarConstantOp<element::Int8>;
using Int32ScalarConstantOp = ScalarConstantOp<element::Int32>;
using Int64ScalarConstantOp = ScalarConstantOp<element::Int64>;
using UInt8ScalarConstantOp = ScalarConstantOp<element::UInt8>;
using UInt32ScalarConstantOp = ScalarConstantOp<element::UInt32>;
using UInt64ScalarConstantOp = ScalarConstantOp<element::UInt64>;
using Float32ScalarConstant = ScalarConstant<element::Float32>;
using Int8ScalarConstant = ScalarConstant<element::Int8>;
using Int32ScalarConstant = ScalarConstant<element::Int32>;
using Int64ScalarConstant = ScalarConstant<element::Int64>;
using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
}
......@@ -16,25 +16,25 @@
namespace ngraph
{
class ConvertOp : public BuiltinOp
{
public:
ConvertOp(const Node::ptr& arg, const ngraph::element::Type& element_type)
ConvertOp(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: BuiltinOp({arg})
, m_element_type(element_type)
{
}
virtual std::string op_class_name() const override { return "convert"; }
virtual std::string get_op_class_name() const override { return "convert"; }
virtual void propagate_types() override;
protected:
const ngraph::element::Type& m_element_type;
};
namespace op
{
std::shared_ptr<ngraph::ConvertOp> convert(const Node::ptr& arg, const ngraph::element::Type& element_type);
std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg,
const ngraph::element::Type& element_type);
}
}
......@@ -20,17 +20,18 @@ namespace ngraph
{
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp(const Node::ptr& arg0, const Node::ptr& arg1)
DotOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "dot"; }
virtual std::string get_op_class_name() const override { return "dot"; }
virtual void propagate_types() override;
};
namespace op
{
Node::ptr dot(const Node::ptr& arg0, const Node::ptr& arg1);
std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
}
}
......@@ -21,11 +21,11 @@ namespace ngraph
{
class Function;
/**
** Parameters are nodes that represent the arguments that will be passed to user-defined functions.
** Function creation requires a sequence of parameters.
** Basic graph operations do not need parameters attached to a function.
**/
///
/// Parameters are nodes that represent the arguments that will be passed to user-defined functions.
/// Function creation requires a sequence of parameters.
/// Basic graph operations do not need parameters attached to a function.
///
class Parameter : public Node
{
friend class Function;
......@@ -36,11 +36,12 @@ namespace ngraph
void assign_function(Function* function, size_t index);
public:
Parameter(const ValueType::ptr& value_type);
Parameter(const std::shared_ptr<ValueType>& value_type);
Parameter(const ngraph::element::Type element_type, const Shape& shape);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
virtual std::string node_id() const override;
virtual std::string get_node_id() const override;
protected:
Function* m_function;
......@@ -50,9 +51,10 @@ namespace ngraph
namespace op
{
/// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type = nullptr);
std::shared_ptr<ngraph::Parameter>
parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
/// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const ngraph::element::Type element_type,
const Shape& shape);
std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape);
}
}
......@@ -18,7 +18,7 @@ namespace ngraph
{
namespace op
{
Node::ptr tuple(const Nodes& args);
std::shared_ptr<Node> tuple(const Nodes& args);
}
class TupleOp : public BuiltinOp
......@@ -29,7 +29,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "tuple"; }
virtual std::string get_op_class_name() const override { return "tuple"; }
virtual void propagate_types() override;
};
}
......@@ -24,9 +24,7 @@ namespace ngraph
class Shape
{
public:
/**
** \param sizes A sequence of sizes.
**/
/// @param sizes A sequence of sizes.
Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes)
{
......@@ -37,12 +35,11 @@ namespace ngraph
{
}
/**
** Conversion to a vector of sizes.
**/
operator const std::vector<size_t>&() const { return m_sizes; }
/// Conversion to a vector of sizes.
operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
protected:
std::vector<size_t> m_sizes;
};
......
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp"
#include "node.hpp"
#include "util.hpp"
using namespace ngraph;
......@@ -21,16 +21,16 @@ using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n)
{
for (auto dn=m_dependent_nodes.begin(); dn!=m_dependent_nodes.end(); dn++)
for (auto dn = m_dependent_nodes.begin(); dn != m_dependent_nodes.end(); dn++)
{
if (dn->first > 0) // Skip zero as they should never be promoted
if (dn->first > 0) // Skip zero as they should never be promoted
{
auto it = find(dn->second.begin(), dn->second.end(), n);
if (it != dn->second.end())
{
// found the node
dn->second.erase(it);
m_dependent_nodes[dn->first-1].push_back(n);
m_dependent_nodes[dn->first - 1].push_back(n);
}
}
}
......@@ -38,9 +38,8 @@ void ngraph::TopologicalSort::promote_node(Node* n)
void ngraph::TopologicalSort::process(node_ptr p)
{
traverse_nodes(p, [&](node_ptr node)
{
list<Node*>& node_list = m_dependent_nodes[node->arguments().size()];
traverse_nodes(p, [&](node_ptr node) {
list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()];
node_list.push_back(node.get());
});
......
......@@ -14,9 +14,10 @@
#pragma once
#include <memory>
#include <map>
#include <list>
#include <map>
#include <memory>
#include <vector>
namespace ngraph
{
......@@ -30,7 +31,7 @@ class ngraph::TopologicalSort
public:
TopologicalSort() {}
void process(node_ptr);
void process(node_ptr);
const std::vector<Node*>& get_sorted_list() const;
private:
......
......@@ -25,123 +25,61 @@ namespace ngraph
class TensorViewType;
class TupleType;
/**
** ValueType is
** TensorViewType
** | TupleType(ValueType[])
**/
/// ValueType is
/// TensorViewType
/// | TupleType(ValueType[])
class ValueType
{
public:
/**
** Preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
virtual ~ValueType() {}
virtual bool operator==(const ValueType::ptr& that) const = 0;
bool operator!=(const ValueType::ptr& that) const { return !(*this == that); }
virtual bool operator==(const std::shared_ptr<ValueType>& that) const = 0;
bool operator!=(const std::shared_ptr<ValueType>& that) const { return !(*this == that); }
};
/**
** Describes a tensor view; an element type and a shape.
**/
/// Describes a tensor view; an element type and a shape.
class TensorViewType : public ValueType
{
public:
/**
** Preferred handle
**/
using ptr = std::shared_ptr<TensorViewType>;
/**
** /param element_type The type of the tensor elements.
** /param shape The shape of the tensor.
**/
/// /param element_type The type of the tensor elements.
/// /param shape The shape of the tensor.
TensorViewType(const element::Type& element_type, const Shape& shape)
: m_element_type(element_type)
, m_shape(shape)
{
}
const element::Type& element_type() const { return m_element_type; }
const Shape& shape() const { return m_shape; }
const element::Type& get_element_type() const { return m_element_type; }
const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType::ptr& that) const override;
virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
protected:
const element::Type& m_element_type;
Shape m_shape;
};
/**
** Describes a tuple of values; a vector of types
**/
/// Describes a tuple of values; a vector of types
class TupleType : public ValueType
{
public:
/**
** The preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
/**
** Construct empty tuple and add value types later.
**/
/// Construct empty tuple and add value types later.
TupleType() {}
/**
** /param element_types A vector of types for the tuple elements
**/
TupleType(const std::vector<ValueType::ptr>& element_types)
: m_element_types(element_types)
{
}
const std::vector<ValueType::ptr> element_types() const { return m_element_types; }
std::vector<ValueType::ptr> element_types() { return m_element_types; }
virtual bool operator==(const ValueType::ptr& that) const override;
protected:
std::vector<ValueType::ptr> m_element_types;
};
/**
** Mixin for objects with type information
**/
class TypedValueMixin
{
public:
TypedValueMixin(const ValueType::ptr& type = nullptr)
: m_type(type)
/// @param element_types A vector of types for the tuple elements
TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types)
: m_element_types(element_types)
{
}
/**
** Set the type
** /param type The new type
**/
void type(const ValueType::ptr& type) { m_type = type; }
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void type(const element::Type& element_type, const Shape& shape)
const std::vector<std::shared_ptr<ValueType>> get_element_types() const
{
m_type = std::make_shared<TensorViewType>(element_type, shape);
return m_element_types;
}
std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; }
/**
** The type associated with this value.
**/
ValueType::ptr type() { return m_type; }
/**
** The type associated with this value.
**/
const ValueType::ptr type() const { return m_type; }
virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
protected:
ValueType::ptr m_type;
std::vector<std::shared_ptr<ValueType>> m_element_types;
};
}
......@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <list>
#include <fstream>
#include <cstdio>
#include <fstream>
#include <list>
#include "visualize.hpp"
#include "ngraph/node.hpp"
#include "util.hpp"
#include "visualize.hpp"
using namespace ngraph;
using namespace std;
......@@ -31,18 +31,17 @@ Visualize::Visualize(const string& name)
void Visualize::add(node_ptr p)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node)
{
for (auto arg : node->arguments())
traverse_nodes(p, [&](node_ptr node) {
for (auto arg : node->get_arguments())
{
m_ss << " " << arg->node_id() << " -> " << node->node_id() << ";\n";
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n";
}
});
}
void Visualize::save_dot(const string& path) const
{
auto tmp_file = path+".tmp";
auto tmp_file = path + ".tmp";
ofstream out(tmp_file);
if (out)
{
......@@ -53,7 +52,7 @@ void Visualize::save_dot(const string& path) const
stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << path;
auto cmd = ss.str();
auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r");
pclose(stream);
......
......@@ -17,22 +17,20 @@
using namespace std;
using namespace ngraph;
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const BroadcastOp::Axes&& broadcast_axes)
/// @param tensor The tensor view to be broadcast.
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
std::shared_ptr<Node> ngraph::op::broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape,
AxisSet&& broadcast_axes)
{
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
......@@ -47,11 +45,11 @@ void BroadcastOp::propagate_types()
{
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != arg_tensor_view_type->shape())
if (Shape{target_shape} != arg_tensor_view_type->get_shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
m_value_type = make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape);
}
......@@ -19,12 +19,12 @@
using namespace std;
using namespace ngraph;
void ConcatenateOp::propagate_types()
void ConcatOp::propagate_types()
{
throw ngraph_error("NIY");
}
Node::ptr op::concatenate(const std::vector<Node::ptr>& args)
std::shared_ptr<Node> op::concatenate(const std::vector<std::shared_ptr<Node>>& args)
{
return make_shared<ConcatenateOp>(args);
return make_shared<ConcatOp>(args);
}
......@@ -16,4 +16,4 @@
using namespace ngraph;
void ScalarConstantBaseOp::propagate_types() {}
void ScalarConstantBase::propagate_types() {}
......@@ -24,7 +24,8 @@ void ConvertOp::propagate_types()
throw ngraph_error("NIY");
}
shared_ptr<ConvertOp> op::convert(const Node::ptr& arg, const element::Type& element_type)
shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg,
const element::Type& element_type)
{
return make_shared<ConvertOp>(arg, element_type);
}
......@@ -20,28 +20,31 @@ using namespace std;
using namespace ngraph;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<DotOp>(arg0, arg1);
}
void DotOp::propagate_types()
{
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->type());
auto arg0_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
if (arg0_tensor_type->element_type() != arg1_tensor_type->element_type())
if (arg0_tensor_type->get_element_type() != arg1_tensor_type->get_element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->shape();
vector<size_t> arg1_shape = arg1_tensor_type->shape();
vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
if (arg1_shape.size() > 1)
......@@ -60,5 +63,5 @@ void DotOp::propagate_types()
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
m_value_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
}
......@@ -17,7 +17,7 @@
using namespace std;
using namespace ngraph;
Function::Function(const Node::ptr& result,
Function::Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters)
: m_result(result)
, m_parameters(parameters)
......@@ -30,13 +30,13 @@ Function::Function(const Node::ptr& result
}
}
shared_ptr<Function> ngraph::op::function(const Node::ptr& result,
shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const initializer_list<shared_ptr<Parameter>>& parameters)
{
return make_shared<Function>(result, parameters);
}
shared_ptr<Function> ngraph::op::function(const Node::ptr& result,
shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const vector<shared_ptr<Parameter>>& parameters)
{
return make_shared<Function>(result, parameters);
......
......@@ -20,24 +20,26 @@
using namespace ngraph;
using namespace std;
std::string ngraph::Op::node_id() const
std::string ngraph::Op::get_node_id() const
{
stringstream ss;
ss << op_class_name() << "_" << m_instance_id;
ss << get_op_class_name() << "_" << m_instance_id;
return ss.str();
}
Node::ptr ngraph::op::abs(const Node::ptr& arg)
std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
{
return make_shared<AbsOp>(arg);
}
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<AddOp>(arg0, arg1);
}
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<CeilingOp>(arg0, arg1);
}
......@@ -45,61 +47,68 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
// 'convert',
// 'convolution',
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<DivideOp>(arg0, arg1);
}
Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
{
return make_shared<ExponentialOp>(arg0);
return make_shared<ExpOp>(arg0);
}
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<FloorOp>(arg0, arg1);
}
Node::ptr ngraph::op::log(const Node::ptr& arg0)
std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
{
return make_shared<LogOp>(arg0);
}
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MaximumOp>(arg0, arg1);
}
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MinimumOp>(arg0, arg1);
}
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MultiplyOp>(arg0, arg1);
}
Node::ptr ngraph::op::negate(const Node::ptr& arg0)
std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
{
return make_shared<NegateOp>(arg0);
return make_shared<NegativeOp>(arg0);
}
// 'pad',
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<PowerOp>(arg0, arg1);
}
//'reduce',
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<RemainderOp>(arg0, arg1);
}
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
{
return make_shared<ReshapeOp>(arg0, shape);
}
......@@ -109,7 +118,8 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
// 'select',
//'slice',
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<SubtractOp>(arg0, arg1);
}
......
......@@ -19,13 +19,18 @@
using namespace std;
using namespace ngraph;
Parameter::Parameter(const ValueType::ptr& value_type)
: Node({}, value_type)
Parameter::Parameter(const std::shared_ptr<ValueType>& value_type)
: Node(value_type)
, m_function(nullptr)
, m_index(0)
{
}
Parameter::Parameter(const ngraph::element::Type element_type, const Shape& shape)
: Parameter(make_shared<TensorViewType>(element_type, shape))
{
}
void Parameter::assign_function(Function* function, size_t index)
{
if (nullptr != m_function)
......@@ -36,11 +41,9 @@ void Parameter::assign_function(Function* function, size_t index)
m_index = index;
}
void Parameter::propagate_types()
{
}
void Parameter::propagate_types() {}
shared_ptr<Parameter> ngraph::op::parameter(const ValueType::ptr& value_type)
shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type)
{
return make_shared<Parameter>(value_type);
}
......@@ -51,7 +54,7 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
std::string ngraph::Parameter::node_id() const
std::string ngraph::Parameter::get_node_id() const
{
stringstream ss;
ss << "parameter_" << m_instance_id;
......
......@@ -24,7 +24,7 @@ void TupleOp::propagate_types()
throw ngraph_error("NIY");
}
Node::ptr op::tuple(const std::vector<Node::ptr>& args)
std::shared_ptr<Node> op::tuple(const std::vector<std::shared_ptr<Node>>& args)
{
return make_shared<TupleOp>(args);
}
......@@ -48,28 +48,3 @@ size_t ngraph::element::Type::size() const
{
return std::ceil((float)m_bitwidth / 8.0);
}
namespace
{
const element::Float s_float32_t = element::Float{"float"};
const element::Int8 s_int8_t = element::Int8{"int8_t"};
const element::Int32 s_int32_t = element::Int32{"int32_t"};
const element::Int64 s_int64_t = element::Int64{"int64_t"};
const element::UInt8 s_uint8_t = element::UInt8{"uint8_t"};
const element::UInt32 s_uint32_t = element::UInt32{"uint32_t"};
const element::UInt64 s_uint64_t = element::UInt64{"uint64_t"};
}
template <>
const element::TraitedType<float>& element::TraitedType<float>::type = s_float32_t;
template <>
const element::TraitedType<int8_t>& element::TraitedType<int8_t>::type = s_int8_t;
template <>
const element::TraitedType<int32_t>& element::TraitedType<int32_t>::type = s_int32_t;
template <>
const element::TraitedType<int64_t>& element::TraitedType<int64_t>::type = s_int64_t;
template <>
const element::TraitedType<uint8_t>& element::TraitedType<uint8_t>::type = s_uint8_t;
template <>
const element::TraitedType<uint32_t>& element::TraitedType<uint32_t>::type = s_uint32_t;
template <>
const element::TraitedType<uint64_t>& element::TraitedType<uint64_t>::type = s_uint64_t;
\ No newline at end of file
......@@ -19,30 +19,30 @@
using namespace std;
using namespace ngraph;
bool TensorViewType::operator==(const ValueType::ptr& that) const
bool TensorViewType::operator==(const std::shared_ptr<ValueType>& that) const
{
auto that_tvt = dynamic_pointer_cast<TensorViewType>(that);
if (nullptr == that_tvt)
{
return false;
}
if (that_tvt->element_type() != m_element_type)
if (that_tvt->get_element_type() != m_element_type)
{
return false;
}
if (that_tvt->shape() != m_shape)
if (that_tvt->get_shape() != m_shape)
{
return false;
}
return true;
}
bool TupleType::operator==(const ValueType::ptr& that) const
bool TupleType::operator==(const std::shared_ptr<ValueType>& that) const
{
auto that_tvt = dynamic_pointer_cast<TupleType>(that);
if (nullptr == that_tvt)
{
return false;
}
return that_tvt->element_types() == element_types();
return that_tvt->get_element_types() == get_element_types();
}
......@@ -26,8 +26,8 @@ void ngraph::dump(ostream& out, const void* _data, size_t _size)
{
auto flags = out.flags();
const uint8_t* data = reinterpret_cast<const uint8_t*>(_data);
int len = _size;
int index = 0;
size_t len = _size;
size_t index = 0;
while (index < len)
{
out << std::hex << std::setw(8) << std::setfill('0') << index;
......@@ -136,11 +136,11 @@ static void traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::set<size_t>& instances_seen)
{
f(p);
for (auto arg : p->arguments())
for (auto arg : p->get_arguments())
{
if (instances_seen.find(arg->instance_id()) == instances_seen.end())
if (instances_seen.find(arg->get_instance_id()) == instances_seen.end())
{
instances_seen.insert(arg->instance_id());
instances_seen.insert(arg->get_instance_id());
traverse_nodes(arg, f, instances_seen);
}
}
......
......@@ -16,40 +16,41 @@
#include "ngraph/ngraph.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
TEST(build_graph, build_simple)
{
// Function with 4 parameters
auto arg0 = op::parameter(element::Float::type, {7, 3});
auto arg1 = op::parameter(element::Float::type, {3});
auto arg2 = op::parameter(element::Float::type, {32, 7});
auto arg3 = op::parameter(element::Float::type, {32, 7});
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto dot = op::dot(arg2, arg0);
ASSERT_EQ(2, dot->arguments().size());
ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0);
auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = node<Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = node<DotOp>(arg2, arg0);
ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0);
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->result(), dot);
ASSERT_EQ(cluster_0->get_result(), dot);
}
// Check upcasting from ValueType.
TEST(build_graph, as_type)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::Float::type, Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
auto tv_vt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
ASSERT_EQ(nullptr, tv_tp);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
ValueType::ptr tp_vt = make_shared<TupleType>(vector<ValueType::ptr>{tv_vt, tv_vt});
auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt);
auto tp_vt = make_shared<TupleType>(ValueTypes{tv_vt, tv_vt});
auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt);
ASSERT_EQ(nullptr, tp_tv);
auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt);
ASSERT_EQ(tp_vt, tp_tp);
......@@ -58,15 +59,15 @@ TEST(build_graph, as_type)
// Check node comparisons
TEST(build_graph, node_comparison)
{
auto arg0 = op::parameter(element::Float::type, {32, 3});
auto arg1 = op::parameter(element::Float::type, {3});
auto arg2 = op::parameter(element::Float::type, {32});
auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{32, 3});
auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32});
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
auto parg = op::parameter(element::Float::type, {});
auto pattern_dot = op::dot(parg, parg);
auto parg = node<Parameter>(element::Float32::element_type(), Shape{});
auto pattern_dot = node<DotOp>(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
......@@ -76,27 +77,26 @@ TEST(build_graph, node_comparison)
TEST(build_graph, literal)
{
// float scalar from a float
auto float0 = FloatScalarConstantOp::make(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::type, Shape{});
ASSERT_EQ(float0->value(), 3.0);
ASSERT_EQ(*float0->type(), float_scalar_type);
auto d = op::dot(float0, float0);
ASSERT_EQ(d->arguments().at(0), float0);
ASSERT_EQ(d->arguments().at(1), float0);
//auto float0 = FloatScalarConstant::make(3.0);
auto float0 = node<Float32ScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
auto d = node<DotOp>(float0, float0);
ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0);
// float scalar from an int
auto float1 = FloatScalarConstantOp::make(3);
ASSERT_EQ(float1->value(), 3);
ASSERT_EQ(*float1->type(), float_scalar_type);
auto int32_0 = Int32ScalarConstantOp::make(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::type, Shape{});
ASSERT_EQ(int32_0->value(), 3);
ASSERT_EQ(*int32_0->type(), int32_scalar_type);
ASSERT_NE(*int32_0->type(), float_scalar_type);
auto float1 = node<Float32ScalarConstant>(3);
ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
auto int32_0 = node<Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->get_value(), 3);
ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type);
ASSERT_NE(*int32_0->get_value_type(), float_scalar_type);
}
// Check argument inverses
TEST(build_graph, arg_inverse)
{
}
TEST(build_graph, arg_inverse) {}
......@@ -23,7 +23,7 @@ using namespace ngraph;
TEST(op, is_op)
{
auto arg0 = op::parameter(element::Float::type, {1});
auto arg0 = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op());
......@@ -31,7 +31,7 @@ TEST(op, is_op)
TEST(op, is_parameter)
{
auto arg0 = op::parameter(element::Float::type, {1});
auto arg0 = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg0);
auto t0 = op::add(arg0, arg0);
ASSERT_NE(nullptr, t0);
......
......@@ -29,21 +29,20 @@ using namespace ngraph;
static bool validate_list(const vector<Node*>& nodes)
{
bool rc = true;
for (auto it=nodes.rbegin(); it!=nodes.rend(); it++)
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{
Node* node = *it;
auto node_tmp = *it;
auto dependencies_tmp = node_tmp->arguments();
auto node_tmp = *it;
auto dependencies_tmp = node_tmp->get_arguments();
vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp)
{
dependencies.push_back(n.get());
}
auto tmp = it+1;
for (; tmp!=nodes.rend(); tmp++)
auto tmp = it + 1;
for (; tmp != nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
if (found != dependencies.end())
{
dependencies.erase(found);
......@@ -60,9 +59,9 @@ static bool validate_list(const vector<Node*>& nodes)
TEST(topological_sort, basic)
{
vector<shared_ptr<Parameter>> args;
for (int i=0; i<10; i++)
for (int i = 0; i < 10; i++)
{
auto arg = op::parameter(element::Float::type, {1});
auto arg = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg);
args.push_back(arg);
}
......@@ -79,13 +78,13 @@ TEST(topological_sort, basic)
auto t4 = op::add(t2, args[5]);
ASSERT_NE(nullptr, t3);
Node::ptr r0 = op::add(t3, t4);
auto r0 = op::add(t3, t4);
ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, args);
ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->arguments().size());
ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0);
Visualize vz;
......
......@@ -134,9 +134,7 @@ TEST(util, contains)
EXPECT_FALSE(contains(v1, 8));
}
TEST(util, remove_from)
{
}
TEST(util, remove_from) {}
TEST(util, reduce)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment