Unverified Commit d87b0065 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Remove TupleType, ValueType (#411)

* Remove TupleType, ValueType

* Fix compile error.
parent f6c6daef
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
shared_ptr<const ngraph::ValueType> descriptor::TensorView::get_value_type() const shared_ptr<const ngraph::TensorViewType> descriptor::TensorView::get_value_type() const
{ {
return m_tensor_view_type; return m_tensor_view_type;
} }
...@@ -51,7 +51,7 @@ namespace ngraph ...@@ -51,7 +51,7 @@ namespace ngraph
virtual const Tensor& get_tensor() const = 0; virtual const Tensor& get_tensor() const = 0;
virtual Tensor& get_tensor() = 0; virtual Tensor& get_tensor() = 0;
virtual std::shared_ptr<const ValueType> get_value_type() const; virtual std::shared_ptr<const TensorViewType> get_value_type() const;
const std::string& get_name() const { return m_name; } const std::string& get_name() const { return m_name; }
std::shared_ptr<const TensorViewType> get_tensor_view_type() const std::shared_ptr<const TensorViewType> get_tensor_view_type() const
......
...@@ -70,7 +70,7 @@ void Node::add_output(const element::Type& element_type, const Shape& shape) ...@@ -70,7 +70,7 @@ void Node::add_output(const element::Type& element_type, const Shape& shape)
m_outputs.emplace_back(this, i, tensor_view_descriptor); m_outputs.emplace_back(this, i, tensor_view_descriptor);
} }
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type) void Node::set_value_type_checked(const shared_ptr<const TensorViewType>& value_type)
{ {
set_value_type_checked(value_type->get_element_type(), value_type->get_shape()); set_value_type_checked(value_type->get_element_type(), value_type->get_shape());
} }
......
...@@ -85,7 +85,7 @@ namespace ngraph ...@@ -85,7 +85,7 @@ namespace ngraph
// value_type agrees with the value type that was set. // value_type agrees with the value type that was set.
// This is used when the framework specifies a value type for the value, and we // This is used when the framework specifies a value type for the value, and we
// independently compute what we thing the value type should be from the arguments. // independently compute what we thing the value type should be from the arguments.
void set_value_type_checked(const std::shared_ptr<const ValueType>& value_type); void set_value_type_checked(const std::shared_ptr<const TensorViewType>& value_type);
void set_value_type_checked(const element::Type& element_type, const Shape& shape); void set_value_type_checked(const element::Type& element_type, const Shape& shape);
bool is_parameter() const; bool is_parameter() const;
......
...@@ -22,16 +22,15 @@ ...@@ -22,16 +22,15 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool ValueType::operator!=(const ValueType& that) const bool TensorViewType::operator!=(const TensorViewType& that) const
{ {
return !(*this == that); return !(*this == that);
} }
bool TensorViewType::operator==(const ValueType& that) const bool TensorViewType::operator==(const TensorViewType& that) const
{ {
bool rc = true; bool rc = true;
auto that_tvt = dynamic_cast<const TensorViewType*>(&that); auto that_tvt = dynamic_cast<const TensorViewType*>(&that);
auto that_tt = dynamic_cast<const TupleType*>(&that);
if (that_tvt != nullptr) if (that_tvt != nullptr)
{ {
rc = true; rc = true;
...@@ -44,10 +43,6 @@ bool TensorViewType::operator==(const ValueType& that) const ...@@ -44,10 +43,6 @@ bool TensorViewType::operator==(const ValueType& that) const
rc = false; rc = false;
} }
} }
else if (that_tt != nullptr)
{
rc = *that_tt == *this;
}
return rc; return rc;
} }
...@@ -57,75 +52,8 @@ void TensorViewType::collect_tensor_views( ...@@ -57,75 +52,8 @@ void TensorViewType::collect_tensor_views(
views.push_back(shared_from_this()); views.push_back(shared_from_this());
} }
bool TupleType::operator==(const ValueType& that) const
{
auto that_tvt = dynamic_cast<const TupleType*>(&that);
if (that_tvt == nullptr)
{
return false;
}
vector<shared_ptr<const ValueType>> this_values = this->get_element_types();
vector<shared_ptr<const ValueType>> that_values = that_tvt->get_element_types();
bool rc = this_values.size() == that_values.size();
if (rc)
{
for (size_t i = 0; i < this_values.size(); i++)
{
rc &= this_values[i]->get_element_type() == that_values[i]->get_element_type();
}
}
return rc;
}
void TupleType::collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const
{
for (auto elt : m_element_types)
{
elt->collect_tensor_views(views);
}
}
const Shape& TupleType::get_shape() const
{
throw ngraph_error("get_shape() called on Tuple");
}
const element::Type& TupleType::get_element_type() const
{
throw ngraph_error("get_element_type() called on Tuple");
}
std::ostream& ngraph::operator<<(std::ostream& out, const ValueType& obj)
{
auto tvt = dynamic_cast<const TensorViewType*>(&obj);
auto tup = dynamic_cast<const TupleType*>(&obj);
if (tvt != nullptr)
{
out << *tvt;
}
else if (tup != nullptr)
{
out << *tup;
}
else
{
out << "ValueType()";
}
return out;
}
std::ostream& ngraph::operator<<(std::ostream& out, const TensorViewType& obj) std::ostream& ngraph::operator<<(std::ostream& out, const TensorViewType& obj)
{ {
out << "TensorViewType(" << obj.m_element_type << ", {" << join(obj.m_shape) << "})"; out << "TensorViewType(" << obj.m_element_type << ", {" << join(obj.m_shape) << "})";
return out; return out;
} }
std::ostream& ngraph::operator<<(std::ostream& out, const TupleType& obj)
{
out << "TupleType()";
return out;
}
...@@ -23,48 +23,28 @@ ...@@ -23,48 +23,28 @@
namespace ngraph namespace ngraph
{ {
class TensorViewType; class TensorViewType;
class TupleType;
/// ValueType is std::ostream& operator<<(std::ostream&, const TensorViewType&);
/// TensorViewType
/// | TupleType(ValueType[])
class ValueType
{
ValueType(const ValueType&) = delete;
ValueType& operator=(const ValueType&) = delete;
protected:
ValueType() {}
public:
virtual ~ValueType() {}
virtual bool operator==(const ValueType& that) const = 0;
bool operator!=(const ValueType& that) const;
/// Add tensor views in depth-first order.
virtual void collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0;
virtual const Shape& get_shape() const = 0;
virtual const element::Type& get_element_type() const = 0;
friend std::ostream& operator<<(std::ostream&, const ValueType&);
};
/// Describes a tensor view; an element type and a shape. /// Describes a tensor view; an element type and a shape.
class TensorViewType : public ValueType, public std::enable_shared_from_this<TensorViewType> class TensorViewType : public std::enable_shared_from_this<TensorViewType>
{ {
TensorViewType& operator=(const ValueType&) = delete;
public: public:
/// /param element_type The type of the tensor elements. /// /param element_type The type of the tensor elements.
/// /param shape The shape of the tensor. /// /param shape The shape of the tensor.
TensorViewType(const element::Type& element_type, const Shape& shape) TensorViewType(const element::Type& element_type, const Shape& shape)
: ValueType() : m_element_type(element_type)
, m_element_type(element_type)
, m_shape(shape) , m_shape(shape)
{ {
} }
virtual const element::Type& get_element_type() const override { return m_element_type; } const element::Type& get_element_type() const { return m_element_type; }
virtual const Shape& get_shape() const override { return m_shape; } const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType& that) const override; bool operator==(const TensorViewType& that) const;
virtual void collect_tensor_views( bool operator!=(const TensorViewType& that) const;
std::vector<std::shared_ptr<const TensorViewType>>& views) const override; void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const;
friend std::ostream& operator<<(std::ostream&, const TensorViewType&); friend std::ostream& operator<<(std::ostream&, const TensorViewType&);
...@@ -72,37 +52,4 @@ namespace ngraph ...@@ -72,37 +52,4 @@ namespace ngraph
const element::Type m_element_type; const element::Type m_element_type;
Shape m_shape; Shape m_shape;
}; };
/// Describes a tuple of values; a vector of types
class TupleType : public ValueType
{
public:
/// Construct empty tuple and add value types later.
TupleType() {}
/// @param element_types A vector of types for the tuple elements
TupleType(const std::vector<std::shared_ptr<const ValueType>>& element_types)
: m_element_types(element_types)
{
}
const std::vector<std::shared_ptr<const ValueType>> get_element_types() const
{
return m_element_types;
}
std::vector<std::shared_ptr<const ValueType>> set_element_types()
{
return m_element_types;
}
virtual const element::Type& get_element_type() const override;
virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
virtual const Shape& get_shape() const override;
friend std::ostream& operator<<(std::ostream&, const TupleType&);
protected:
std::vector<std::shared_ptr<const ValueType>> m_element_types;
};
} }
...@@ -39,24 +39,6 @@ TEST(build_graph, build_simple) ...@@ -39,24 +39,6 @@ TEST(build_graph, build_simple)
ASSERT_EQ(cluster_0->get_output_op(0), dot); ASSERT_EQ(cluster_0->get_output_op(0), dot);
} }
// Check upcasting from ValueType.
TEST(build_graph, as_type)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
auto tv_vt = make_shared<TensorViewType>(element::f32, Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
ASSERT_EQ(nullptr, tv_tp);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
auto tp_vt = make_shared<TupleType>(ValueTypes{tv_vt, tv_vt});
auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt);
ASSERT_EQ(nullptr, tp_tv);
auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt);
ASSERT_EQ(tp_vt, tp_tp);
}
// Check node comparisons // Check node comparisons
TEST(build_graph, node_comparison) TEST(build_graph, node_comparison)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment