Commit 125b1a85 authored by Scott Cyphers's avatar Scott Cyphers

Get tuples out of autodiff

parent 7ff8e1f2
...@@ -30,9 +30,6 @@ ...@@ -30,9 +30,6 @@
using namespace ngraph; using namespace ngraph;
/// @brief Make a zero matching a value type.
std::shared_ptr<Node> make_zero(const std::shared_ptr<const ValueType>& value_type);
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& tensor_view_type) std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& tensor_view_type)
{ {
std::shared_ptr<Node> zero = std::shared_ptr<Node> zero =
...@@ -50,34 +47,6 @@ std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& ten ...@@ -50,34 +47,6 @@ std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& ten
return zero; return zero;
} }
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TupleType>& tuple_type)
{
std::vector<std::shared_ptr<Node>> elements;
for (auto& value_type : tuple_type->get_element_types())
{
elements.push_back(make_zero(value_type));
}
return std::make_shared<op::Tuple>(elements);
}
std::shared_ptr<Node> make_zero(const std::shared_ptr<const ValueType>& value_type)
{
std::shared_ptr<const TensorViewType> tensor_view_type =
std::dynamic_pointer_cast<const TensorViewType>(value_type);
if (nullptr != tensor_view_type)
{
return (make_zero(tensor_view_type));
}
std::shared_ptr<const TupleType> tuple_type =
std::dynamic_pointer_cast<const TupleType>(value_type);
if (nullptr != tuple_type)
{
return make_zero(tuple_type);
}
// Should be impossible
throw ngraph_error("Unknown value type");
}
autodiff::Adjoints::Adjoints(const std::shared_ptr<Node>& y, const std::shared_ptr<Node>& c) autodiff::Adjoints::Adjoints(const std::shared_ptr<Node>& y, const std::shared_ptr<Node>& c)
{ {
// Pass 1 determines which nodes contribute to y as well as setting up a reverse // Pass 1 determines which nodes contribute to y as well as setting up a reverse
...@@ -143,7 +112,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x) ...@@ -143,7 +112,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
auto adjoint_it = m_adjoint_map.find(x.get()); auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it) if (m_adjoint_map.end() == adjoint_it)
{ {
auto result = make_zero(x->get_value_type()); auto result = make_zero(x->get_outputs().at(0).get_tensor_view_type());
adjoint_it = m_adjoint_map.insert({x.get(), result}).first; adjoint_it = m_adjoint_map.insert({x.get(), result}).first;
} }
return adjoint_it->second; return adjoint_it->second;
......
...@@ -47,6 +47,11 @@ namespace ngraph ...@@ -47,6 +47,11 @@ namespace ngraph
const std::set<Input*>& get_inputs() const { return m_inputs; } const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const; const Tensor& get_tensor() const;
Tensor& get_tensor(); Tensor& get_tensor();
/// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{
return get_tensor_view()->get_tensor_view_type();
}
protected: protected:
Node* m_node; Node* m_node;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment