Commit 1cf260ec authored by Robert Kimball's avatar Robert Kimball

remove class Op and merge functionality with Node

parent 6a0ac42e
...@@ -74,27 +74,24 @@ void Node::assign_tensors() ...@@ -74,27 +74,24 @@ void Node::assign_tensors()
} }
} }
bool Node::is_op() const bool Node::is_parameter() const
{ {
return dynamic_cast<const Op*>(this) != nullptr; return dynamic_cast<const op::Parameter*>(this) != nullptr;
} }
bool Node::is_parameter() const std::string Node::get_node_id() const
{ {
return dynamic_cast<const op::Parameter*>(this) != nullptr; stringstream ss;
ss << description() << "_" << m_instance_id;
return ss.str();
} }
namespace ngraph namespace ngraph
{ {
ostream& operator<<(ostream& out, const Node& node) ostream& operator<<(ostream& out, const Node& node)
{ {
auto op_tmp = dynamic_cast<const Op*>(&node); auto parameter_tmp = dynamic_cast<const op::Parameter*>(&node);
auto parameter_tmp = dynamic_cast<const Op*>(&node); if (parameter_tmp)
if (op_tmp)
{
out << "Op(" << op_tmp->get_node_id() << ")";
}
else if (parameter_tmp)
{ {
out << "Parameter(" << parameter_tmp->get_node_id() << ")"; out << "Parameter(" << parameter_tmp->get_node_id() << ")";
} }
......
...@@ -25,8 +25,6 @@ ...@@ -25,8 +25,6 @@
namespace ngraph namespace ngraph
{ {
class Op;
namespace descriptor namespace descriptor
{ {
class Input; class Input;
...@@ -71,7 +69,7 @@ namespace ngraph ...@@ -71,7 +69,7 @@ namespace ngraph
std::string get_name() const { return m_name; } std::string get_name() const { return m_name; }
void set_name(const std::string& name) { m_name = name; } void set_name(const std::string& name) { m_name = name; }
virtual std::string get_node_id() const = 0; virtual std::string get_node_id() const;
/// Return true if this has the same implementing class as node. This /// Return true if this has the same implementing class as node. This
/// will be used by the pattern matcher when comparing a pattern /// will be used by the pattern matcher when comparing a pattern
...@@ -100,7 +98,6 @@ namespace ngraph ...@@ -100,7 +98,6 @@ namespace ngraph
// independently compute what we thing the value type should be from the arguments. // independently compute what we thing the value type should be from the arguments.
void set_value_type_checked(const std::shared_ptr<ValueType>& value_type); void set_value_type_checked(const std::shared_ptr<ValueType>& value_type);
bool is_op() const;
bool is_parameter() const; bool is_parameter() const;
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
......
...@@ -22,32 +22,12 @@ ...@@ -22,32 +22,12 @@
namespace ngraph namespace ngraph
{ {
/// Op nodes are nodes whose value is the result of some operation
/// applied to its arguments. For calls to user functions, the op will
/// reference the user function.
class Op : public Node
{
public:
Op(const std::vector<std::shared_ptr<Node>>& arguments)
: Node(arguments)
{
}
Op()
: Node()
{
}
virtual std::string get_op_class_name() const = 0;
virtual std::string get_node_id() const override;
};
// TODO: These class definitions are to be moved into separate files in the op directory // TODO: These class definitions are to be moved into separate files in the op directory
namespace op namespace op
{ {
/// A Function invokes a function on node arguments. In addition to the argument /// A Function invokes a function on node arguments. In addition to the argument
/// we need to preserve the function. /// we need to preserve the function.
class FunctionCall : public Op class FunctionCall : public Node
{ {
virtual std::string description() const override { return "FunctionCall"; } virtual std::string description() const override { return "FunctionCall"; }
...@@ -57,14 +37,14 @@ namespace ngraph ...@@ -57,14 +37,14 @@ namespace ngraph
/// The is an operation we handle directly, i.e. all type checking, etc. /// The is an operation we handle directly, i.e. all type checking, etc.
/// are defined in C++ rather than in terms of ngraph operations. /// are defined in C++ rather than in terms of ngraph operations.
class Builtin : public Op class Builtin : public Node
{ {
public: public:
virtual std::string description() const override { return "Builtin"; } virtual std::string description() const override { return "Builtin"; }
protected: protected:
Builtin(const std::vector<std::shared_ptr<Node>>& args) Builtin(const std::vector<std::shared_ptr<Node>>& args)
: Op(args) : Node(args)
{ {
} }
}; };
...@@ -88,7 +68,7 @@ namespace ngraph ...@@ -88,7 +68,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Reshape"; } virtual std::string description() const override { return "Reshape"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
...@@ -147,7 +127,7 @@ namespace ngraph ...@@ -147,7 +127,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "BinaryElementwiseComparison"; } virtual std::string description() const override { return "BinaryElementwiseComparison"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
virtual const element::Type& propagate_element_types( virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg0_element_type,
...@@ -163,7 +143,7 @@ namespace ngraph ...@@ -163,7 +143,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "BinaryElementwiseArithmetic"; } virtual std::string description() const override { return "BinaryElementwiseArithmetic"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
virtual const element::Type& propagate_element_types( virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg0_element_type,
......
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Abs"; } virtual std::string description() const override { return "Abs"; }
}; };
} }
} }
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1) : BinaryElementwiseArithmetic(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Add"; } virtual std::string description() const override { return "Add"; }
}; };
} }
} }
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Broadcast"; } virtual std::string description() const override { return "Broadcast"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
......
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Ceiling"; } virtual std::string description() const override { return "Ceiling"; }
}; };
} }
} }
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Concatenate"; } virtual std::string description() const override { return "Concatenate"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
......
...@@ -27,7 +27,7 @@ namespace ngraph ...@@ -27,7 +27,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Convert"; } virtual std::string description() const override { return "Convert"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
......
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Divide"; } virtual std::string description() const override { return "Divide"; }
}; };
} }
} }
...@@ -45,7 +45,7 @@ namespace ngraph ...@@ -45,7 +45,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Dot"; } virtual std::string description() const override { return "Dot"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
......
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1) : BinaryElementwiseComparison(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Equal"; } virtual std::string description() const override { return "Equal"; }
}; };
} }
} }
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Exp"; } virtual std::string description() const override { return "Exp"; }
}; };
} }
} }
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Floor"; } virtual std::string description() const override { return "Floor"; }
}; };
} }
} }
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1) : BinaryElementwiseComparison(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Greater"; } virtual std::string description() const override { return "Greater"; }
}; };
} }
} }
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1) : BinaryElementwiseComparison(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Less"; } virtual std::string description() const override { return "Less"; }
}; };
} }
} }
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Log"; } virtual std::string description() const override { return "Log"; }
}; };
} }
} }
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1) : BinaryElementwiseArithmetic(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Maximum"; } virtual std::string description() const override { return "Maximum"; }
}; };
} }
} }
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1) : BinaryElementwiseArithmetic(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Minimum"; } virtual std::string description() const override { return "Minimum"; }
}; };
} }
} }
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Multiply"; } virtual std::string description() const override { return "Multiply"; }
}; };
} }
} }
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Negative"; } virtual std::string description() const override { return "Negative"; }
}; };
} }
} }
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1) : BinaryElementwiseArithmetic(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Power"; } virtual std::string description() const override { return "Power"; }
}; };
} }
} }
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1) : BinaryElementwiseArithmetic(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Remainder"; } virtual std::string description() const override { return "Remainder"; }
}; };
} }
} }
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Subtract"; } virtual std::string description() const override { return "Subtract"; }
}; };
} }
} }
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string get_op_class_name() const override { return "Tuple"; } virtual std::string description() const override { return "Tuple"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
......
...@@ -60,13 +60,9 @@ std::string Visualize::get_attributes(const Node* node) ...@@ -60,13 +60,9 @@ std::string Visualize::get_attributes(const Node* node)
{ {
ss << " " << node->get_node_id() << " [shape=box color=blue]\n"; ss << " " << node->get_node_id() << " [shape=box color=blue]\n";
} }
else if (node->is_op())
{
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
}
else else
{ {
ss << " " << node->get_node_id() << " [shape=diamond color=red]\n"; ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
} }
return ss.str(); return ss.str();
} }
......
...@@ -19,10 +19,3 @@ ...@@ -19,10 +19,3 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
std::string ngraph::Op::get_node_id() const
{
stringstream ss;
ss << get_op_class_name() << "_" << m_instance_id;
return ss.str();
}
...@@ -26,7 +26,6 @@ TEST(op, is_op) ...@@ -26,7 +26,6 @@ TEST(op, is_op)
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg0); ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter()); EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op());
} }
TEST(op, is_parameter) TEST(op, is_parameter)
...@@ -36,5 +35,4 @@ TEST(op, is_parameter) ...@@ -36,5 +35,4 @@ TEST(op, is_parameter)
auto t0 = make_shared<op::Add>(arg0, arg0); auto t0 = make_shared<op::Add>(arg0, arg0);
ASSERT_NE(nullptr, t0); ASSERT_NE(nullptr, t0);
EXPECT_FALSE(t0->is_parameter()); EXPECT_FALSE(t0->is_parameter());
EXPECT_TRUE(t0->is_op());
} }
...@@ -107,17 +107,15 @@ TEST(benchmark, topological_sort) ...@@ -107,17 +107,15 @@ TEST(benchmark, topological_sort)
result = make_cell(result, in_1, in_2); result = make_cell(result, in_1, in_2);
} }
auto op_r0 = static_pointer_cast<Op>(result);
timer.start(); timer.start();
pass::TopologicalSort ts; pass::TopologicalSort ts;
ts.run_on_tree(op_r0); ts.run_on_tree(result);
auto sorted_list = ts.get_call_graph(); auto sorted_list = ts.get_call_graph();
timer.stop(); timer.stop();
INFO << "topological sort took " << timer.get_milliseconds() << "ms"; INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0; size_t node_count = 0;
traverse_nodes(op_r0, [&](const Node* node) { traverse_nodes(result, [&](const Node* node) {
node_count++; node_count++;
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment