Commit 8b44c451 authored by Scott Cyphers's avatar Scott Cyphers

Review comments.

parent 4a70cad9
...@@ -21,9 +21,7 @@ ...@@ -21,9 +21,7 @@
namespace ngraph namespace ngraph
{ {
/** /// A user-defined function.
** A user-defined function.
**/
class Function class Function
{ {
public: public:
......
...@@ -27,11 +27,9 @@ namespace ngraph ...@@ -27,11 +27,9 @@ namespace ngraph
{ {
class Op; class Op;
/** /// Nodes are the backbone of the graph of Value dataflow. Every node has
** Nodes are the backbone of the graph of Value dataflow. Every node has /// zero or more nodes as arguments and one value, which is either a tensor
** zero or more nodes as arguments and one value, which is either a tensor /// view or a (possibly empty) tuple of values.
** view or a (possibly empty) tuple of values.
**/
class Node : public TypedValueMixin class Node : public TypedValueMixin
{ {
public: public:
...@@ -57,11 +55,9 @@ namespace ngraph ...@@ -57,11 +55,9 @@ namespace ngraph
virtual std::string get_node_id() const = 0; virtual std::string get_node_id() const = 0;
/** /// Return true if this has the same implementing class as node. This
** Return true if this has the same implementing class as node. This /// will be used by the pattern matcher when comparing a pattern
** will be used by the pattern matcher when comparing a pattern /// graph against the graph.
** graph against the graph.
**/
bool is_same_op_type(const Node::ptr& node) const bool is_same_op_type(const Node::ptr& node) const
{ {
return typeid(*this) == typeid(*node.get()); return typeid(*this) == typeid(*node.get());
......
...@@ -61,11 +61,9 @@ namespace ngraph ...@@ -61,11 +61,9 @@ namespace ngraph
//Node::ptr while(); //Node::ptr while();
} }
/** /// Op nodes are nodes whose value is the result of some operation
** Op nodes are nodes whose value is the result of some operation /// applied to its arguments. For calls to user functions, the op will
** applied to its arguments. For calls to user functions, the op will /// reference the user function.
** reference the user function.
**/
class Op : public Node class Op : public Node
{ {
public: public:
...@@ -78,10 +76,8 @@ namespace ngraph ...@@ -78,10 +76,8 @@ namespace ngraph
virtual std::string get_node_id() const override; virtual std::string get_node_id() const override;
}; };
/** /// A FunctionOp invokes a function on node arguments. In addition to the argument
** A FunctionOp invokes a function on node arguments. In addition to the argument /// we need to preserve the function.
** we need to preserve the function.
**/
class FunctionOp : public Op class FunctionOp : public Op
{ {
virtual std::string description() const override { return "FunctionOp"; } virtual std::string description() const override { return "FunctionOp"; }
...@@ -89,10 +85,8 @@ namespace ngraph ...@@ -89,10 +85,8 @@ namespace ngraph
Node::ptr m_function; Node::ptr m_function;
}; };
/** /// The is an operation we handle directly, i.e. all type checking, etc.
** The is an operation we handle directly, i.e. all type checking, etc. /// are defined in C++ rather than in terms of ngraph operations.
** are defined in C++ rather than in terms of ngraph operations.
**/
class BuiltinOp : public Op class BuiltinOp : public Op
{ {
public: public:
......
...@@ -19,12 +19,12 @@ namespace ngraph ...@@ -19,12 +19,12 @@ namespace ngraph
class BroadcastOp : public BuiltinOp class BroadcastOp : public BuiltinOp
{ {
public: public:
/** ///
** /param arg The tensor view to be broadcast. /// @param arg The tensor view to be broadcast.
** /param shape The shape of the result /// @param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast. /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg. /// the remaining axes in shape must be the same as the shape of arg.
**/ ///
BroadcastOp(const Node::ptr& arg, const Shape& shape, const AxisSet& broadcast_axes) BroadcastOp(const Node::ptr& arg, const Shape& shape, const AxisSet& broadcast_axes)
: BuiltinOp({arg}) : BuiltinOp({arg})
, m_shape(shape) , m_shape(shape)
......
...@@ -21,11 +21,11 @@ namespace ngraph ...@@ -21,11 +21,11 @@ namespace ngraph
{ {
class Function; class Function;
/** ///
** Parameters are nodes that represent the arguments that will be passed to user-defined functions. /// Parameters are nodes that represent the arguments that will be passed to user-defined functions.
** Function creation requires a sequence of parameters. /// Function creation requires a sequence of parameters.
** Basic graph operations do not need parameters attached to a function. /// Basic graph operations do not need parameters attached to a function.
**/ ///
class Parameter : public Node class Parameter : public Node
{ {
friend class Function; friend class Function;
......
...@@ -24,9 +24,7 @@ namespace ngraph ...@@ -24,9 +24,7 @@ namespace ngraph
class Shape class Shape
{ {
public: public:
/** /// @param sizes A sequence of sizes.
** \param sizes A sequence of sizes.
**/
Shape(const std::initializer_list<size_t>& sizes) Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes) : m_sizes(sizes)
{ {
...@@ -37,9 +35,7 @@ namespace ngraph ...@@ -37,9 +35,7 @@ namespace ngraph
{ {
} }
/** /// Conversion to a vector of sizes.
** Conversion to a vector of sizes.
**/
operator const std::vector<size_t>&() const { return m_sizes; } operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; } bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; } bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
......
...@@ -25,11 +25,9 @@ namespace ngraph ...@@ -25,11 +25,9 @@ namespace ngraph
class TensorViewType; class TensorViewType;
class TupleType; class TupleType;
/** /// ValueType is
** ValueType is /// TensorViewType
** TensorViewType /// | TupleType(ValueType[])
** | TupleType(ValueType[])
**/
class ValueType class ValueType
{ {
public: public:
...@@ -43,21 +41,15 @@ namespace ngraph ...@@ -43,21 +41,15 @@ namespace ngraph
bool operator!=(const ValueType::ptr& that) const { return !(*this == that); } bool operator!=(const ValueType::ptr& that) const { return !(*this == that); }
}; };
/** /// Describes a tensor view; an element type and a shape.
** Describes a tensor view; an element type and a shape.
**/
class TensorViewType : public ValueType class TensorViewType : public ValueType
{ {
public: public:
/** // Preferred handle
** Preferred handle
**/
using ptr = std::shared_ptr<TensorViewType>; using ptr = std::shared_ptr<TensorViewType>;
/** /// /param element_type The type of the tensor elements.
** /param element_type The type of the tensor elements. /// /param shape The shape of the tensor.
** /param shape The shape of the tensor.
**/
TensorViewType(const element::Type& element_type, const Shape& shape) TensorViewType(const element::Type& element_type, const Shape& shape)
: m_element_type(element_type) : m_element_type(element_type)
, m_shape(shape) , m_shape(shape)
...@@ -74,24 +66,16 @@ namespace ngraph ...@@ -74,24 +66,16 @@ namespace ngraph
Shape m_shape; Shape m_shape;
}; };
/** /// Describes a tuple of values; a vector of types
** Describes a tuple of values; a vector of types
**/
class TupleType : public ValueType class TupleType : public ValueType
{ {
public: public:
/**
** The preferred handle
**/
using ptr = std::shared_ptr<ValueType>; using ptr = std::shared_ptr<ValueType>;
/** /// Construct empty tuple and add value types later.
** Construct empty tuple and add value types later.
**/
TupleType() {} TupleType() {}
/**
** /param element_types A vector of types for the tuple elements /// @param element_types A vector of types for the tuple elements
**/
TupleType(const std::vector<ValueType::ptr>& element_types) TupleType(const std::vector<ValueType::ptr>& element_types)
: m_element_types(element_types) : m_element_types(element_types)
{ {
......
...@@ -17,12 +17,10 @@ ...@@ -17,12 +17,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
/** /// @param tensor The tensor view to be broadcast.
** /param arg The tensor view to be broadcast. /// @param shape The shape of the result
** /param shape The shape of the result /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast. /// the remaining axes in shape must be the same as the shape of arg.
** the remaining axes in shape must be the same as the shape of arg.
**/
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor, Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape, const Shape& shape,
AxisSet&& broadcast_axes) AxisSet&& broadcast_axes)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment