Commit 966d2cc2 authored by Scott Cyphers's avatar Scott Cyphers

Add some comments.

parent 0c06b371
...@@ -21,6 +21,7 @@ namespace ngraph ...@@ -21,6 +21,7 @@ namespace ngraph
/// Base error for ngraph runtime errors. /// Base error for ngraph runtime errors.
struct ngraph_error : std::runtime_error struct ngraph_error : std::runtime_error
{ {
explicit ngraph_error(const std::string& what_arg) explicit ngraph_error(const std::string& what_arg)
: std::runtime_error(what_arg) : std::runtime_error(what_arg)
{ {
......
...@@ -23,13 +23,16 @@ namespace ngraph ...@@ -23,13 +23,16 @@ namespace ngraph
class Function; class Function;
/** /**
** One parameter of a function. Within the function's graph ** Parameters are nodes that represent the arguments that will be passed to user-defined functions.
** the parameter is a node that represents the argument in a call. ** Function creation requires a sequence of parameters.
**/ ** Basic graph operations do not need parameters attached to a function.
**/
class Parameter : public Node class Parameter : public Node
{ {
friend class Function; friend class Function;
protected: protected:
// Called by the Function constructor to associate this parameter with the function.
// It is an error to try to associate a parameter with more than one function.
void assign_function(Function* function, size_t index); void assign_function(Function* function, size_t index);
public: public:
...@@ -46,7 +49,9 @@ namespace ngraph ...@@ -46,7 +49,9 @@ namespace ngraph
namespace op namespace op
{ {
/// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type=nullptr); std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type=nullptr);
/// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const ngraph::element::Type element_type, const Shape& shape); std::shared_ptr<ngraph::Parameter> parameter(const ngraph::element::Type element_type, const Shape& shape);
} }
} }
...@@ -22,23 +22,17 @@ using namespace ngraph; ...@@ -22,23 +22,17 @@ using namespace ngraph;
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto cluster_0 = make_shared<Function>(4); auto arg0 = op::parameter(element::float32_t, {7, 3});
cluster_0->result()->type(element::float32_t, {32, 3}); auto arg1 = op::parameter(element::float32_t, {3});
cluster_0->parameter(0)->type(element::float32_t, {7, 3}); auto arg2 = op::parameter(element::float32_t, {32, 7});
cluster_0->parameter(1)->type(element::float32_t, {3}); auto arg3 = op::parameter(element::float32_t, {32, 7});
cluster_0->parameter(2)->type(element::float32_t, {32, 7});
cluster_0->parameter(3)->type(element::float32_t, {32, 7});
auto arg3 = cluster_0->parameter(3);
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0}); auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
auto dot = op::dot(arg2, arg0); auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2); ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0); ASSERT_EQ(dot->arguments()[1], arg0);
// Function returns tuple of dot and broadcast_1.
cluster_0->result()->value(dot);
ASSERT_EQ(cluster_0->result()->value(), dot); auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->result(), dot);
} }
// Check upcasting from ValueType. // Check upcasting from ValueType.
...@@ -62,20 +56,14 @@ TEST(build_graph, as_type) ...@@ -62,20 +56,14 @@ TEST(build_graph, as_type)
// Check node comparisons // Check node comparisons
TEST(build_graph, node_comparison) TEST(build_graph, node_comparison)
{ {
auto fun = make_shared<Function>(3); auto arg0 = op::parameter(element::float32_t, {32, 3});
fun->parameter(0)->type(element::float32_t, {32, 3}); auto arg1 = op::parameter(element::float32_t, {3});
fun->parameter(1)->type(element::float32_t, {3}); auto arg2 = op::parameter(element::float32_t, {32});
fun->parameter(2)->type(element::float32_t, {32});
auto arg0 = fun->parameter(0);
auto arg1 = fun->parameter(1);
auto arg2 = fun->parameter(2);
auto dot = op::dot(arg0, arg1); auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2); auto add = op::add(dot, arg2);
auto pattern = make_shared<Function>(1); auto parg = op::parameter(element::float32_t, {});
pattern->parameter(0)->type(element::float32_t, {});
auto parg = pattern->parameter(0);
auto pattern_dot = op::dot(parg, parg); auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot)); ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented. // TODO This passes because typeid is not behaving as documented.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment