Commit 1c965079 authored by Scott Cyphers's avatar Scott Cyphers

Inverse arguments, missing file in last PR

parent 2bb6d51d
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include <set>
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
...@@ -30,18 +31,27 @@ namespace ngraph ...@@ -30,18 +31,27 @@ namespace ngraph
class Node : public TypedValueMixin class Node : public TypedValueMixin
{ {
public: public:
using ptr = std::shared_ptr<Node>; using ptr = std::shared_ptr<Node>;
Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = nullptr) Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = nullptr)
: TypedValueMixin(type) : TypedValueMixin(type)
, m_arguments(arguments) , m_arguments(arguments)
{ {
// Add this node as a user of each argument.
for(auto node : m_arguments){
node->m_users.insert(node.get());
}
} }
const std::vector<Node::ptr> arguments() const { return m_arguments; } const std::vector<Node::ptr> arguments() const { return m_arguments; }
std::vector<Node::ptr> arguments() { return m_arguments; } std::vector<Node::ptr> arguments() { return m_arguments; }
const std::multiset<Node*> users() const { return m_users; }
std::multiset<Node*> users() { return m_users; }
protected: protected:
std::vector<Node::ptr> m_arguments; std::vector<Node::ptr> m_arguments;
std::multiset<Node*> m_users;
}; };
} }
...@@ -31,12 +31,12 @@ TEST(graph, build_simple) ...@@ -31,12 +31,12 @@ TEST(graph, build_simple)
auto arg3 = cluster_0->parameter(3); auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 1. // call broadcast op on arg3, broadcasting on axis 1.
auto broadcast_1 = op::broadcast(arg3, 1); auto broadcast_1 = op::broadcast(arg3, 1);
auto arg2 = cluster_0->parameter(2); auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0); auto arg0 = cluster_0->parameter(0);
// call dot op // call dot op
auto dot = op::dot(arg2, arg0); auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->dependents()[0], arg2); ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->dependents()[1], arg0); ASSERT_EQ(dot->arguments()[1], arg0);
// Function returns tuple of dot and broadcast_1. // Function returns tuple of dot and broadcast_1.
cluster_0->result()->value(dot); cluster_0->result()->value(dot);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment