Commit 51808ee1 authored by Scott Cyphers's avatar Scott Cyphers

dynamic_ptr_cast

parent 27ce1b0a
......@@ -39,13 +39,6 @@ namespace ngraph
using ptr = std::shared_ptr<ValueType>;
virtual ~ValueType() {}
/**
** Unmanaged cast to a supertype. dynamic_cast cannot be used
** directly on a shared_ptr. Can use dynamic_pointer_cast if
** a shared_ptr is needed.
**/
template<typename T> T as() { return dynamic_cast<T>(this); }
};
/**
......
......@@ -49,7 +49,7 @@ void BroadcastOp::propagate_types()
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = arg_type->as<TensorViewType*>();
auto arg_tensor_view_type = dynamic_pointer_cast<TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
......@@ -91,8 +91,8 @@ Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
void DotOp::propagate_types()
{
auto arg0_tensor_type = m_arguments.at(0)->type()->as<TensorViewType*>();
auto arg1_tensor_type = m_arguments.at(1)->type()->as<TensorViewType*>();
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
......
......@@ -48,17 +48,17 @@ TEST(build_graph, as_type)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::float32_t, Shape{2, 3, 5});
TensorViewType* tv_tv = tv_vt->as<TensorViewType*>();
ASSERT_EQ(tv_vt.get(), tv_tv);
TupleType* tv_tp = tv_vt->as<TupleType*>();
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
ASSERT_EQ(nullptr, tv_tp);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
ValueType::ptr tp_vt = make_shared<TupleType>(vector<ValueType::ptr>{tv_vt, tv_vt});
TensorViewType* tp_tv = tp_vt->as<TensorViewType*>();
auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt);
ASSERT_EQ(nullptr, tp_tv);
TupleType* tp_tp = tp_vt->as<TupleType*>();
ASSERT_EQ(tp_vt.get(), tp_tp);
auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt);
ASSERT_EQ(tp_vt, tp_tp);
}
// Check Call comparisons
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment