Commit 1c965079 authored by Scott Cyphers's avatar Scott Cyphers

Inverse arguments, missing file in last PR

parent 2bb6d51d
......@@ -15,6 +15,7 @@
#pragma once
#include <vector>
#include <set>
#include "ngraph/type.hpp"
......@@ -30,18 +31,27 @@ namespace ngraph
class Node : public TypedValueMixin
{
public:
using ptr = std::shared_ptr<Node>;
Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = nullptr)
: TypedValueMixin(type)
, m_arguments(arguments)
{
// Add this node as a user of each argument.
for(auto node : m_arguments){
node->m_users.insert(node.get());
}
}
const std::vector<Node::ptr> arguments() const { return m_arguments; }
std::vector<Node::ptr> arguments() { return m_arguments; }
const std::multiset<Node*> users() const { return m_users; }
std::multiset<Node*> users() { return m_users; }
protected:
std::vector<Node::ptr> m_arguments;
std::multiset<Node*> m_users;
};
}
......@@ -31,12 +31,12 @@ TEST(graph, build_simple)
auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 1.
auto broadcast_1 = op::broadcast(arg3, 1);
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
// call dot op
auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->dependents()[0], arg2);
ASSERT_EQ(dot->dependents()[1], arg0);
ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0);
// Function returns tuple of dot and broadcast_1.
cluster_0->result()->value(dot);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment