Commit 1cf260ec authored by Robert Kimball's avatar Robert Kimball

remove class Op and merge functionality with Node

parent 6a0ac42e
......@@ -74,27 +74,24 @@ void Node::assign_tensors()
}
}
bool Node::is_op() const
bool Node::is_parameter() const
{
return dynamic_cast<const Op*>(this) != nullptr;
return dynamic_cast<const op::Parameter*>(this) != nullptr;
}
bool Node::is_parameter() const
std::string Node::get_node_id() const
{
return dynamic_cast<const op::Parameter*>(this) != nullptr;
stringstream ss;
ss << description() << "_" << m_instance_id;
return ss.str();
}
namespace ngraph
{
ostream& operator<<(ostream& out, const Node& node)
{
auto op_tmp = dynamic_cast<const Op*>(&node);
auto parameter_tmp = dynamic_cast<const Op*>(&node);
if (op_tmp)
{
out << "Op(" << op_tmp->get_node_id() << ")";
}
else if (parameter_tmp)
auto parameter_tmp = dynamic_cast<const op::Parameter*>(&node);
if (parameter_tmp)
{
out << "Parameter(" << parameter_tmp->get_node_id() << ")";
}
......
......@@ -25,8 +25,6 @@
namespace ngraph
{
class Op;
namespace descriptor
{
class Input;
......@@ -71,7 +69,7 @@ namespace ngraph
std::string get_name() const { return m_name; }
void set_name(const std::string& name) { m_name = name; }
virtual std::string get_node_id() const = 0;
virtual std::string get_node_id() const;
/// Return true if this has the same implementing class as node. This
/// will be used by the pattern matcher when comparing a pattern
......@@ -100,7 +98,6 @@ namespace ngraph
// independently compute what we thing the value type should be from the arguments.
void set_value_type_checked(const std::shared_ptr<ValueType>& value_type);
bool is_op() const;
bool is_parameter() const;
size_t get_instance_id() const { return m_instance_id; }
......
......@@ -22,32 +22,12 @@
namespace ngraph
{
/// Op nodes are nodes whose value is the result of some operation
/// applied to its arguments. For calls to user functions, the op will
/// reference the user function.
class Op : public Node
{
public:
Op(const std::vector<std::shared_ptr<Node>>& arguments)
: Node(arguments)
{
}
Op()
: Node()
{
}
virtual std::string get_op_class_name() const = 0;
virtual std::string get_node_id() const override;
};
// TODO: These class definitions are to be moved into separate files in the op directory
namespace op
{
/// A Function invokes a function on node arguments. In addition to the argument
/// we need to preserve the function.
class FunctionCall : public Op
class FunctionCall : public Node
{
virtual std::string description() const override { return "FunctionCall"; }
......@@ -57,14 +37,14 @@ namespace ngraph
/// The is an operation we handle directly, i.e. all type checking, etc.
/// are defined in C++ rather than in terms of ngraph operations.
class Builtin : public Op
class Builtin : public Node
{
public:
virtual std::string description() const override { return "Builtin"; }
protected:
Builtin(const std::vector<std::shared_ptr<Node>>& args)
: Op(args)
: Node(args)
{
}
};
......@@ -88,7 +68,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Reshape"; }
virtual std::string description() const override { return "Reshape"; }
//virtual void propagate_types() override;
protected:
Shape m_shape;
......@@ -147,7 +127,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "BinaryElementwiseComparison"; }
virtual std::string description() const override { return "BinaryElementwiseComparison"; }
//virtual void propagate_types() override;
virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type,
......@@ -163,7 +143,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "BinaryElementwiseArithmetic"; }
virtual std::string description() const override { return "BinaryElementwiseArithmetic"; }
//virtual void propagate_types() override;
virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type,
......
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Abs"; }
virtual std::string description() const override { return "Abs"; }
};
}
}
......@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Add"; }
virtual std::string description() const override { return "Add"; }
};
}
}
......@@ -36,7 +36,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Broadcast"; }
virtual std::string description() const override { return "Broadcast"; }
virtual void propagate_types() override;
protected:
......
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Ceiling"; }
virtual std::string description() const override { return "Ceiling"; }
};
}
}
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Concatenate"; }
virtual std::string description() const override { return "Concatenate"; }
virtual void propagate_types() override;
};
}
......
......@@ -27,7 +27,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Convert"; }
virtual std::string description() const override { return "Convert"; }
virtual void propagate_types() override;
protected:
......
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Divide"; }
virtual std::string description() const override { return "Divide"; }
};
}
}
......@@ -45,7 +45,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Dot"; }
virtual std::string description() const override { return "Dot"; }
virtual void propagate_types() override;
};
}
......
......@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Equal"; }
virtual std::string description() const override { return "Equal"; }
};
}
}
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Exp"; }
virtual std::string description() const override { return "Exp"; }
};
}
}
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Floor"; }
virtual std::string description() const override { return "Floor"; }
};
}
}
......@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Greater"; }
virtual std::string description() const override { return "Greater"; }
};
}
}
......@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Less"; }
virtual std::string description() const override { return "Less"; }
};
}
}
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Log"; }
virtual std::string description() const override { return "Log"; }
};
}
}
......@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Maximum"; }
virtual std::string description() const override { return "Maximum"; }
};
}
}
......@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Minimum"; }
virtual std::string description() const override { return "Minimum"; }
};
}
}
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Multiply"; }
virtual std::string description() const override { return "Multiply"; }
};
}
}
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Negative"; }
virtual std::string description() const override { return "Negative"; }
};
}
}
......@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Power"; }
virtual std::string description() const override { return "Power"; }
};
}
}
......@@ -25,7 +25,7 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Remainder"; }
virtual std::string description() const override { return "Remainder"; }
};
}
}
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Subtract"; }
virtual std::string description() const override { return "Subtract"; }
};
}
}
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string get_op_class_name() const override { return "Tuple"; }
virtual std::string description() const override { return "Tuple"; }
virtual void propagate_types() override;
};
}
......
......@@ -60,13 +60,9 @@ std::string Visualize::get_attributes(const Node* node)
{
ss << " " << node->get_node_id() << " [shape=box color=blue]\n";
}
else if (node->is_op())
{
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
}
else
{
ss << " " << node->get_node_id() << " [shape=diamond color=red]\n";
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
}
return ss.str();
}
......
......@@ -19,10 +19,3 @@
using namespace ngraph;
using namespace std;
std::string ngraph::Op::get_node_id() const
{
stringstream ss;
ss << get_op_class_name() << "_" << m_instance_id;
return ss.str();
}
......@@ -26,7 +26,6 @@ TEST(op, is_op)
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op());
}
TEST(op, is_parameter)
......@@ -36,5 +35,4 @@ TEST(op, is_parameter)
auto t0 = make_shared<op::Add>(arg0, arg0);
ASSERT_NE(nullptr, t0);
EXPECT_FALSE(t0->is_parameter());
EXPECT_TRUE(t0->is_op());
}
......@@ -107,17 +107,15 @@ TEST(benchmark, topological_sort)
result = make_cell(result, in_1, in_2);
}
auto op_r0 = static_pointer_cast<Op>(result);
timer.start();
pass::TopologicalSort ts;
ts.run_on_tree(op_r0);
ts.run_on_tree(result);
auto sorted_list = ts.get_call_graph();
timer.stop();
INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0;
traverse_nodes(op_r0, [&](const Node* node) {
traverse_nodes(result, [&](const Node* node) {
node_count++;
});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment