Commit db6e3052 authored by Scott Cyphers's avatar Scott Cyphers

Add test for type upcasting.

parent 1b026daa
...@@ -42,7 +42,8 @@ namespace ngraph ...@@ -42,7 +42,8 @@ namespace ngraph
/** /**
** Unmanaged cast to a supertype. dynamic_cast cannot be used ** Unmanaged cast to a supertype. dynamic_cast cannot be used
** directly on a shared_ptr. ** directly on a shared_ptr. Can use dynamic_pointer_cast if
** a shared_ptr is needed.
**/ **/
template<typename T> T as() { return dynamic_cast<T>(this); } template<typename T> T as() { return dynamic_cast<T>(this); }
}; };
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
TEST(ngraph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto cluster_0 = make_shared<Function>(4); auto cluster_0 = make_shared<Function>(4);
...@@ -42,3 +42,21 @@ TEST(ngraph, build_simple) ...@@ -42,3 +42,21 @@ TEST(ngraph, build_simple)
ASSERT_EQ(cluster_0->result()->value(), dot); ASSERT_EQ(cluster_0->result()->value(), dot);
} }
// Check upcasting from ValueType.
TEST(build_graph, as_type)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::float32_t, Shape{2, 3, 5});
TensorViewType* tv_tv = tv_vt->as<TensorViewType*>();
ASSERT_EQ(tv_vt.get(), tv_tv);
TupleType* tv_tp = tv_vt->as<TupleType*>();
ASSERT_EQ(nullptr, tv_tp);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
ValueType::ptr tp_vt = make_shared<TupleType>(vector<ValueType::ptr>{tv_vt, tv_vt});
TensorViewType* tp_tv = tp_vt->as<TensorViewType*>();
ASSERT_EQ(nullptr, tp_tv);
TupleType* tp_tp = tp_vt->as<TupleType*>();
ASSERT_EQ(tp_vt.get(), tp_tp);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment