Commit 11ed222d authored by Scott Cyphers

Numeric/sym comparison works.

parent 420c1abe
@@ -87,7 +87,6 @@ ngraph::runtime::FunctionSpec::operator std::shared_ptr<Function>() const
     return std::make_shared<ngraph::Function>(m_result, m_result_type, m_parameters);
 }
-// Returns (dy/(dXs))(C, Xs)
 std::shared_ptr<ngraph::runtime::FunctionSpec>
     ngraph::runtime::derivative(const std::shared_ptr<ngraph::runtime::FunctionSpec>& f)
 {
@@ -121,7 +120,7 @@ std::vector<std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>>
     Shape y_shape =
         std::dynamic_pointer_cast<const ngraph::TensorViewType>(y->get_value_type())->get_shape();
-    // Check all the shapes
+    // Results for each derivative, shape Y|X_i
     std::vector<std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>> results;
     for (size_t i = 0; i < args.size(); i++)
     {
@@ -141,11 +140,11 @@ std::vector<std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>>
     args_tv.insert(args_tv.begin(), args.begin(), args.end());
     cf->tensor_call(args_tv, TensorViewPtrs{ref_y});
-    auto ref_vec = ref_y->get_vector();
+    auto& ref_vec = ref_y->get_vector();
     // inc_y will hold f(x+dx) values
     auto inc_y = backend->make_parameterized_tensor_view<ET>(y_shape);
-    auto inc_vec = inc_y->get_vector();
+    auto& inc_vec = inc_y->get_vector();
     // Assuming vars, y, and results are row-major
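The auto → auto& changes above are not cosmetic: get_vector() returns a reference to the tensor's backing storage, and binding it with plain auto deduces a value type, silently copying the whole buffer. Reads through the copy then never see later writes to the tensor. A minimal reproduction of the pitfall, using a stand-in type rather than the nGraph classes:

#include <cassert>
#include <vector>

struct Tensor
{
    std::vector<float> data{0, 0, 0};
    std::vector<float>& get_vector() { return data; }
};

int main()
{
    Tensor t;
    auto copy = t.get_vector();  // auto deduces std::vector<float>: a full copy
    auto& ref = t.get_vector();  // auto& aliases the live buffer
    t.data[0] = 42;
    assert(copy[0] == 0);        // the copy is frozen at its old contents
    assert(ref[0] == 42);        // the reference observes the write
}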
@@ -155,8 +154,8 @@ std::vector<std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>>
         auto arg = args[i];
         auto df_darg = results[i];
         auto df_darg_it = df_darg->get_vector().begin();
-        std::vector<typename ET::type>& vec = arg->get_vector();
+        auto& vec = arg->get_vector();
-        for (size_t j = 0; j < vec.size(); i++)
+        for (size_t j = 0; j < vec.size(); j++)
         {
             auto old_val = vec[j];
             vec[j] += delta;
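The substantive fix in this hunk is the loop increment: the inner loop previously advanced i instead of j, so j never moved and the outer index ran away. (The auto& change here is stylistic; the explicit declaration was already a reference.) The loop itself implements a forward-difference derivative estimate: perturb one input element by delta, re-evaluate, divide. A self-contained sketch under illustrative names, not the nGraph API:

#include <cstddef>
#include <functional>
#include <vector>

// Forward-difference estimate of the gradient of a scalar f over a flat
// parameter vector: grad[j] ~= (f(x + delta*e_j) - f(x)) / delta.
std::vector<float> numeric_gradient(const std::function<float(const std::vector<float>&)>& f,
                                    std::vector<float> x,
                                    float delta)
{
    std::vector<float> grad(x.size());
    float fx = f(x); // reference value, computed once
    for (std::size_t j = 0; j < x.size(); j++) // j++, not i++
    {
        float old_val = x[j];
        x[j] += delta;                  // perturb one coordinate in place...
        grad[j] = (f(x) - fx) / delta;
        x[j] = old_val;                 // ...and restore it before moving on
    }
    return grad;
}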
@@ -247,7 +246,7 @@ std::vector<std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>>
     TensorViewPtrs bprops_tv;
     bprops_tv.insert(bprops_tv.begin(), bprops.begin(), bprops.end());
-    auto c_vec = c_arg->get_vector();
+    auto& c_vec = c_arg->get_vector();
     for (size_t i = 0; i < c_vec.size(); i++)
     {
         c_vec[i] = 1;
@@ -255,9 +254,8 @@ std::vector<std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>>
         c_vec[i] = 0;
         for (size_t j = 0; j < results.size(); j++)
         {
-            auto bprop_vec = bprops[j]->get_vector();
-            result_pos[j] =
-                results[j]->get_vector().insert(result_pos[j], bprop_vec.begin(), bprop_vec.end());
+            auto& bprop_vec = bprops[j]->get_vector();
+            result_pos[j] = std::copy(bprop_vec.begin(), bprop_vec.end(), result_pos[j]);
         }
     }
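Two things are going on in this reverse-mode loop. The outer loop seeds c_vec one output element at a time (set to 1, run backprop, reset to 0), so each pass yields one row of the Jacobian. The inner loop switches from vector::insert, which grows the destination and can reallocate, to std::copy into what is presumably preallocated storage; std::copy returns the iterator one past the last element written, so result_pos[j] remains a valid write cursor across iterations. A condensed sketch of the pattern, with hypothetical names:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// bprop(c) is assumed to return the adjoint d(y.c)/dx for one input; seeding
// c with each unit vector in turn recovers the full Jacobian row by row.
std::vector<float> full_jacobian(
    const std::function<std::vector<float>(const std::vector<float>&)>& bprop,
    std::size_t y_size,
    std::size_t x_size)
{
    std::vector<float> c(y_size, 0.0f);
    std::vector<float> jacobian(y_size * x_size); // preallocated once
    auto pos = jacobian.begin();                  // write cursor
    for (std::size_t i = 0; i < y_size; i++)
    {
        c[i] = 1; // one-hot adjoint selects output element i
        std::vector<float> row = bprop(c);
        c[i] = 0;
        // Overwrite in place; std::copy returns the advanced cursor, whereas
        // insert() would reallocate and invalidate previously saved iterators.
        pos = std::copy(row.begin(), row.end(), pos);
    }
    return jacobian;
}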
@@ -90,6 +90,7 @@ namespace ngraph
                        double rtol,
                        double atol);
+        /// @brief Contains the information in a Function, but can be used to construct derived functions such as derivatives.
         class FunctionSpec
         {
         public:
@@ -125,6 +126,9 @@ namespace ngraph
             std::vector<std::shared_ptr<op::Parameter>> m_parameters;
         };
+        /// @brief Returns a FunctionSpec for the backprop derivative of its argument.
+        /// @param f is f(X_i...)
+        /// @returns f'(c, X_i...) -> tuple of tensors in same order as in X_i
         std::shared_ptr<ngraph::runtime::FunctionSpec>
             derivative(const std::shared_ptr<ngraph::runtime::FunctionSpec>& f);
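The new doc comments pin down the contract: the derived function takes an adjoint tensor c as its first parameter and returns one tensor per input X_i, each with X_i's shape. Worked through for the multiply test below, where f(x0, x1) = x0 ⊙ x1 elementwise, the backprop derivative is

    f'(c, x0, x1) = (c ⊙ x1, c ⊙ x0)

so calling it with one-hot values of c recovers the Jacobian row by row, which is exactly what the comparison against numeric_derivative exercises.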
@@ -91,7 +91,7 @@ TEST(backwards, multiply)
     auto f_num = make_graph();
     auto results_num =
-        runtime::numeric_derivative<element::Float32>(manager, backend, f_num, {x0, x1}, .01f);
+        runtime::numeric_derivative<element::Float32>(manager, backend, f_num, {x0, x1}, .001f);
     auto f_sym = make_graph();
     auto results_sym =
         runtime::backwards_derivative<element::Float32>(manager, backend, f_sym, {x0, x1});
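Shrinking the step from .01f to .001f matches the usual error trade-off for a forward difference (assuming that is numeric_derivative's scheme, as the f(x+dx) comment in the implementation suggests): truncation error shrinks linearly with delta, while float32 roundoff error grows like eps/delta with eps ≈ 1.2e-7, putting the sweet spot near sqrt(eps) ≈ 3.5e-4 for unit-scale values. The smaller step buys roughly an order of magnitude of accuracy before roundoff takes over.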
...@@ -99,7 +99,7 @@ TEST(backwards, multiply) ...@@ -99,7 +99,7 @@ TEST(backwards, multiply)
{ {
auto result_num = results_num[i]; auto result_num = results_num[i];
auto result_sym = results_sym[i]; auto result_sym = results_sym[i];
bool ac = all_close(result_num, result_sym); bool ac = all_close(result_num, result_sym, .01f, .01f);
EXPECT_TRUE(ac); EXPECT_TRUE(ac);
} }
} }
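The comparison now passes explicit tolerances instead of relying on all_close's defaults, accepting about 1% disagreement between the numeric and symbolic results, which is reasonable given the O(delta) bias of the forward difference. A common definition of such a check, assumed here for illustration rather than taken from the nGraph sources:

#include <cmath>
#include <cstddef>
#include <vector>

// Element-wise closeness: |a[i] - b[i]| <= atol + rtol * |b[i]| for all i.
bool all_close(const std::vector<float>& a,
               const std::vector<float>& b,
               float rtol,
               float atol)
{
    if (a.size() != b.size())
        return false;
    for (std::size_t i = 0; i < a.size(); i++)
    {
        if (std::fabs(a[i] - b[i]) > atol + rtol * std::fabs(b[i]))
            return false;
    }
    return true;
}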