Unverified Commit 07ba1bef authored by Matthew Brookhart's avatar Matthew Brookhart Committed by GitHub

fix boolean ops to return the input element::type instead of float32 (#356)

parent 8c4ae5ea
...@@ -28,9 +28,8 @@ void ngraph::op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -28,9 +28,8 @@ void ngraph::op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints,
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
auto y = get_input_op(1); auto y = get_input_op(1);
adjoints.add_delta( adjoints.add_delta(
x, delta * make_shared<op::Convert>(make_shared<op::Greater>(x, y), element::f32)); x, delta * make_shared<op::Convert>(make_shared<op::Greater>(x, y), x->get_element_type()));
adjoints.add_delta( adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::Greater>(y, x), element::f32)); y, delta * make_shared<op::Convert>(make_shared<op::Greater>(y, x), y->get_element_type()));
} }
...@@ -29,8 +29,8 @@ void ngraph::op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -29,8 +29,8 @@ void ngraph::op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints,
auto x = get_input_op(0); auto x = get_input_op(0);
auto y = get_input_op(1); auto y = get_input_op(1);
adjoints.add_delta(x, adjoints.add_delta(
delta * make_shared<op::Convert>(make_shared<op::Less>(x, y), element::f32)); x, delta * make_shared<op::Convert>(make_shared<op::Less>(x, y), x->get_element_type()));
adjoints.add_delta(y, adjoints.add_delta(
delta * make_shared<op::Convert>(make_shared<op::Less>(y, x), element::f32)); y, delta * make_shared<op::Convert>(make_shared<op::Less>(y, x), y->get_element_type()));
} }
...@@ -56,9 +56,10 @@ void ngraph::op::Select::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -56,9 +56,10 @@ void ngraph::op::Select::generate_adjoints(autodiff::Adjoints& adjoints,
auto x = get_inputs().at(1).get_output().get_node(); auto x = get_inputs().at(1).get_output().get_node();
auto y = get_inputs().at(2).get_output().get_node(); auto y = get_inputs().at(2).get_output().get_node();
auto p_as_float = std::make_shared<op::Convert>(p, element::f32); auto p_as_x_type = std::make_shared<op::Convert>(p, x->get_element_type());
auto not_p_as_float = std::make_shared<op::Convert>(std::make_shared<op::Not>(p), element::f32); auto not_p_as_y_type =
std::make_shared<op::Convert>(std::make_shared<op::Not>(p), y->get_element_type());
adjoints.add_delta(x, delta * p_as_float); adjoints.add_delta(x, delta * p_as_x_type);
adjoints.add_delta(y, delta * not_p_as_float); adjoints.add_delta(y, delta * not_p_as_y_type);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment