Commit 5f8bf07e authored by Scott Cyphers's avatar Scott Cyphers

Finish implementation of type casting.

parent 689c22d8
......@@ -39,8 +39,12 @@ namespace ngraph
using ptr = std::shared_ptr<ValueType>;
virtual ~ValueType() {}
virtual std::shared_ptr<TensorViewType> as_tensor_view_type() { return nullptr; }
virtual std::shared_ptr<TupleType> as_tuple_type() { return nullptr; }
/**
** Unmanaged cast to a supertype. dynamic_cast cannot be used
** directly on a shared_ptr.
**/
template<typename T> T as() { return dynamic_cast<T>(this); }
};
/**
......
......@@ -55,7 +55,7 @@ void BroadcastCall::propagate_types()
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = arg_type->as_tensor_view_type();
auto arg_tensor_view_type = arg_type->as<TensorViewType*>();
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
......@@ -103,8 +103,8 @@ Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
void DotCall::propagate_types()
{
auto arg0_tensor_type = m_arguments.at(0)->type()->as_tensor_view_type();
auto arg1_tensor_type = m_arguments.at(1)->type()->as_tensor_view_type();
auto arg0_tensor_type = m_arguments.at(0)->type()->as<TensorViewType*>();
auto arg1_tensor_type = m_arguments.at(1)->type()->as<TensorViewType*>();
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment