Commit b0c72a83 authored by Scott Cyphers's avatar Scott Cyphers

Add more doc.

parent b527258d
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/** /**
** Call nodes are nodes whose value is the result of some operation, the op, ** Call nodes are nodes whose value is the result of some operation, the op,
** applied to its arguments. We use the op as a callable to construct the ** applied to its arguments. We use the op as a callable to construct the
** call nodes. ** call nodes. For calls to user functions, the op will be the user function.
**/ **/
class Call : public Node class Call : public Node
{ {
...@@ -56,7 +56,9 @@ namespace ngraph ...@@ -56,7 +56,9 @@ namespace ngraph
}; };
/** /**
** There is exactly one instance of builtin op for each pre-defined operation. ** There is exactly one instance of builtin op for each pre-defined operation. These
** are intended to be used when matching calls in different graphs; every FooCall
** will have the same op.
**/ **/
class BuiltinOp : public Op class BuiltinOp : public Op
{ {
...@@ -124,6 +126,7 @@ namespace ngraph ...@@ -124,6 +126,7 @@ namespace ngraph
class DotCall : public BuiltinCall class DotCall : public BuiltinCall
{ {
public: public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotCall(const Node::ptr& arg0, const Node::ptr& arg1) DotCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall(s_op, {arg0, arg1})
{ {
......
...@@ -19,7 +19,13 @@ using namespace std; ...@@ -19,7 +19,13 @@ using namespace std;
std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadcast"); std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadcast");
shared_ptr<Node> ngraph::op::broadcast(const Node::ptr& tensor, /**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
shared_ptr<Node> ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape, const Shape& shape,
const vector<size_t>& broadcast_axes) const vector<size_t>& broadcast_axes)
{ {
...@@ -28,6 +34,7 @@ shared_ptr<Node> ngraph::op::broadcast(const Node::ptr& tensor, ...@@ -28,6 +34,7 @@ shared_ptr<Node> ngraph::op::broadcast(const Node::ptr& tensor,
std::shared_ptr<BuiltinOp> DotCall::s_op = make_shared<BuiltinOp>("dot"); std::shared_ptr<BuiltinOp> DotCall::s_op = make_shared<BuiltinOp>("dot");
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
shared_ptr<Node> ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1) shared_ptr<Node> ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<DotCall>(arg0, arg1); return make_shared<DotCall>(arg0, arg1);
......
...@@ -29,7 +29,7 @@ TEST(DISABLED_graph, build_simple) ...@@ -29,7 +29,7 @@ TEST(DISABLED_graph, build_simple)
cluster_0->parameter(2)->type(element::float32_t, {32, 7}); cluster_0->parameter(2)->type(element::float32_t, {32, 7});
cluster_0->parameter(3)->type(element::float32_t, {32, 7}); cluster_0->parameter(3)->type(element::float32_t, {32, 7});
auto arg3 = cluster_0->parameter(3); auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 1. // call broadcast op on arg3, broadcasting on axis 0.
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0}); auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto arg2 = cluster_0->parameter(2); auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0); auto arg0 = cluster_0->parameter(0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment