Commit 125b1a85 authored by Scott Cyphers's avatar Scott Cyphers

Get tuples out of autodiff

parent 7ff8e1f2
......@@ -30,9 +30,6 @@
using namespace ngraph;
/// @brief Make a zero matching a value type.
std::shared_ptr<Node> make_zero(const std::shared_ptr<const ValueType>& value_type);
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& tensor_view_type)
{
std::shared_ptr<Node> zero =
......@@ -50,34 +47,6 @@ std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& ten
return zero;
}
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TupleType>& tuple_type)
{
std::vector<std::shared_ptr<Node>> elements;
for (auto& value_type : tuple_type->get_element_types())
{
elements.push_back(make_zero(value_type));
}
return std::make_shared<op::Tuple>(elements);
}
std::shared_ptr<Node> make_zero(const std::shared_ptr<const ValueType>& value_type)
{
std::shared_ptr<const TensorViewType> tensor_view_type =
std::dynamic_pointer_cast<const TensorViewType>(value_type);
if (nullptr != tensor_view_type)
{
return (make_zero(tensor_view_type));
}
std::shared_ptr<const TupleType> tuple_type =
std::dynamic_pointer_cast<const TupleType>(value_type);
if (nullptr != tuple_type)
{
return make_zero(tuple_type);
}
// Should be impossible
throw ngraph_error("Unknown value type");
}
autodiff::Adjoints::Adjoints(const std::shared_ptr<Node>& y, const std::shared_ptr<Node>& c)
{
// Pass 1 determines which nodes contribute to y as well as setting up a reverse
......@@ -143,7 +112,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it)
{
auto result = make_zero(x->get_value_type());
auto result = make_zero(x->get_outputs().at(0).get_tensor_view_type());
adjoint_it = m_adjoint_map.insert({x.get(), result}).first;
}
return adjoint_it->second;
......
......@@ -47,6 +47,11 @@ namespace ngraph
const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
/// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{
return get_tensor_view()->get_tensor_view_type();
}
protected:
Node* m_node;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment