Unverified Commit 9ebedbbf authored by Scott Cyphers, committed by GitHub

Add Output support for a number of builders (#3297)

* Add Output support for a number of builders

* ../src/ngraph/builder/autobroadcast.cpp

* Remove some _values

* Simplify

* Simplify, add some more ops/builders

* ONNX failures

* Some GetOutputElement changes to help with Output<Node>

* Convert some ops to use Output<Node> inputs

* Review comments

* Some more changes of nodes to outputs

* Change names for OutputVector helpers to avoid API change

* Cleanup
parent 8b0a2d19
@@ -34,18 +34,18 @@ namespace ngraph
{
namespace detail
{
shared_ptr<Node> lp_norm(const shared_ptr<Node>& node,
shared_ptr<Node> lp_norm(const Output<Node>& value,
size_t p_norm,
const AxisSet& reduction_axes,
float bias)
{
// In general, the "entrywise" lp-norm of a matrix `A` is defined by the following double sum:
// ||A||_p = ||vec(A)||_p = [sum_{i=1}^m sum_{j=1}^n abs(a_{i,j})^p]^{1/p}
shared_ptr<Node> abs_values{make_shared<op::Abs>(node)};
shared_ptr<Node> abs_values{make_shared<op::Abs>(value)};
shared_ptr<Node> p_node = op::Constant::create(
node->get_element_type(),
node->get_shape(),
vector<float>(shape_size(node->get_shape()), static_cast<float>(p_norm)));
value.get_element_type(),
value.get_shape(),
vector<float>(shape_size(value.get_shape()), static_cast<float>(p_norm)));
// Get inner part of equation: abs_values^p_node, then sum over reduction_axes.
shared_ptr<Node> values{make_shared<op::Power>(abs_values, p_node)};
@@ -68,26 +68,26 @@ namespace ngraph
}
}
shared_ptr<Node> l0_norm(const shared_ptr<Node>& node, const AxisSet& reduction_axes)
shared_ptr<Node> l0_norm(const Output<Node>& value, const AxisSet& reduction_axes)
{
// L0 norm returns number of elements different from zero.
shared_ptr<Node> zero_node{
op::Constant::create(node->get_element_type(),
node->get_shape(),
vector<float>(shape_size(node->get_shape()), 0.f))};
op::Constant::create(value.get_element_type(),
value.get_shape(),
vector<float>(shape_size(value.get_shape()), 0.f))};
// Convert bool values to input node data type.
shared_ptr<Node> non_zero_values = make_shared<op::Convert>(
make_shared<op::NotEqual>(node, zero_node), node->get_element_type());
make_shared<op::NotEqual>(value, zero_node), value.get_element_type());
return make_shared<op::Sum>(non_zero_values, reduction_axes);
}
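For reference, the graph built above (a `NotEqual` against a zero constant, a `Convert` back to the input element type, then a `Sum`) computes the usual counting definition of the L-0 "norm", with the sum taken over `reduction_axes`:

\[
\|x\|_0 = \sum_i \mathbf{1}\!\left[x_i \neq 0\right]
\]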
shared_ptr<Node>
l1_norm(const shared_ptr<Node>& node, const AxisSet& reduction_axes, float bias)
l1_norm(const Output<Node>& value, const AxisSet& reduction_axes, float bias)
{
shared_ptr<Node> values{
make_shared<op::Sum>(make_shared<op::Abs>(node), reduction_axes)};
make_shared<op::Sum>(make_shared<op::Abs>(value), reduction_axes)};
shared_ptr<Node> bias_node{
op::Constant::create(values->get_element_type(),
@@ -98,9 +98,9 @@ namespace ngraph
}
shared_ptr<Node>
l2_norm(const shared_ptr<Node>& node, const AxisSet& reduction_axes, float bias)
l2_norm(const Output<Node>& value, const AxisSet& reduction_axes, float bias)
{
shared_ptr<Node> values{make_shared<op::Sum>(node * node, reduction_axes)};
shared_ptr<Node> values{make_shared<op::Sum>(value * value, reduction_axes)};
shared_ptr<Node> bias_node{
op::Constant::create(values->get_element_type(),
@@ -110,7 +110,7 @@ namespace ngraph
return {make_shared<op::Sqrt>(values + bias_node)};
}
shared_ptr<Node> lp_norm(const shared_ptr<Node>& node,
shared_ptr<Node> lp_norm(const Output<Node>& value,
const AxisSet& reduction_axes,
size_t p_norm,
float bias)
@@ -118,22 +118,22 @@ namespace ngraph
// The number of non-zero elements
if (p_norm == 0)
{
return l0_norm(node, reduction_axes);
return l0_norm(value, reduction_axes);
}
// sum of absolute values.
else if (p_norm == 1)
{
return l1_norm(node, reduction_axes, bias);
return l1_norm(value, reduction_axes, bias);
}
// sqrt of sum of squares - Euclidean norm
else if (p_norm == 2)
{
return l2_norm(node, reduction_axes, bias);
return l2_norm(value, reduction_axes, bias);
}
// generic case
else
{
return detail::lp_norm(node, p_norm, reduction_axes, bias);
return detail::lp_norm(value, p_norm, reduction_axes, bias);
}
}
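The dispatch above routes p = 0, 1, 2 to the specialized builders and everything else to `detail::lp_norm`. A minimal usage sketch of the updated entry point follows; the `ngraph/builder/norm.hpp` include path is an assumption here.

```cpp
#include <memory>

#include "ngraph/builder/norm.hpp" // assumed include path for the builders above
#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main()
{
    // A 2x3 float input, reduced along axis 1.
    auto param = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    std::shared_ptr<Node> data = param;

    // p_norm = 3 falls through to detail::lp_norm; 0, 1, or 2 would hit
    // the specialized l0/l1/l2 builders instead.
    auto norm = builder::lp_norm(data, AxisSet{1}, /*p_norm=*/3);

    auto f = std::make_shared<Function>(NodeVector{norm}, ParameterVector{param});
    return 0;
}
```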
@@ -25,62 +25,57 @@ namespace ngraph
{
namespace builder
{
/// \brief Creates node which calculates L-0 norm of input tensor.
/// \brief Calculates L-0 norm of input tensor.
///
/// \note The L-0 norm represents the cardinality of elements different
/// from zero. Strictly speaking, it is not a true norm.
///
/// \param[in] node The input tensor node.
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
///
/// \return Node which calculates L-0 norm values.
/// \return L-0 norm of value.
///
std::shared_ptr<Node> l0_norm(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes);
std::shared_ptr<Node> l0_norm(const Output<Node>& value, const AxisSet& reduction_axes);
/// \brief Creates node which calculates L-1 norm of input tensor.
/// \brief Calculates L-1 norm of a value.
///
/// \note The L-1 norm represents the sum of absolute values.
///
/// \param[in] node The input tensor node.
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias added to the calculated sum.
///
/// \return Node which calculates L-1 norm values.
/// \return L-1 norm of value.
///
std::shared_ptr<Node> l1_norm(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes,
float bias = 0.f);
std::shared_ptr<Node>
l1_norm(const Output<Node>& value, const AxisSet& reduction_axes, float bias = 0.f);
/// \brief Calculates L-2 norm of input tensor.
///
/// \note The L-2 norm represents the square root of sum of squares of each
/// individual element.
///
/// \param[in] node The input tensor node.
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias added to the calculated sum.
///
/// \return Node which calculates L-2 norm values.
/// \return L-2 norm of value.
///
std::shared_ptr<Node> l2_norm(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes,
float bias = 0.f);
std::shared_ptr<Node>
l2_norm(const Output<Node>& value, const AxisSet& reduction_axes, float bias = 0.f);
/// \brief Calculates L-p norm of a value.
///
/// \param[in] node The input nGraph tensor.
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] p_norm The p norm to calculate.
/// \param[in] bias The bias added to the calculated sum.
///
/// \return Node which calculates L-p norm.
/// \return L-p norm of value.
///
std::shared_ptr<Node> lp_norm(const std::shared_ptr<Node>& node,
std::shared_ptr<Node> lp_norm(const Output<Node>& value,
const AxisSet& reduction_axes,
std::size_t p_norm = 2,
float bias = 0.f);
} // namespace builder
} // namespace ngraph
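A point worth noting about the new signatures: `Output<Node>` is implicitly constructible from a node `shared_ptr` (which is how these changes avoid an API break, per the commit message), so existing call sites keep compiling. A small sketch under that assumption, with the include path also assumed:

```cpp
#include <memory>

#include "ngraph/builder/norm.hpp" // assumed include path
#include "ngraph/ngraph.hpp"

using namespace ngraph;

void norm_examples()
{
    std::shared_ptr<Node> x = std::make_shared<op::Parameter>(element::f32, Shape{4, 5});

    auto l1 = builder::l1_norm(x, AxisSet{0});            // old-style shared_ptr argument
    auto l2 = builder::l2_norm(x->output(0), AxisSet{0}); // explicit Output<Node> argument
}
```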
@@ -42,8 +42,7 @@ namespace ngraph
return N;
}
std::shared_ptr<Node> l2_norm(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes)
std::shared_ptr<Node> l2_norm(const Output<Node>& node, const AxisSet& reduction_axes)
{
auto x2 = node * node;
auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes);
@@ -51,19 +50,19 @@ namespace ngraph
return std::make_shared<op::Sqrt>(x2sum);
}
std::shared_ptr<Node> mean(const std::shared_ptr<Node>& node, const AxisSet& reduction_axes)
std::shared_ptr<Node> mean(const Output<Node>& value, const AxisSet& reduction_axes)
{
auto xsum = std::make_shared<op::Sum>(node, reduction_axes);
auto xsum = std::make_shared<op::Sum>(value, reduction_axes);
auto N = get_num_elements(node->get_shape(), reduction_axes);
const auto& et = node->get_element_type();
auto N = get_num_elements(value.get_shape(), reduction_axes);
const auto& et = value.get_element_type();
auto divisor = op::Constant::create(et, xsum->get_shape(), {N});
return xsum / divisor;
}
std::shared_ptr<Node> std_dev(const std::shared_ptr<Node>& node,
std::shared_ptr<Node> std_dev(const Output<Node>& node,
const AxisSet& reduction_axes,
const bool bessel_correction)
{
@@ -74,13 +73,13 @@ namespace ngraph
// The second might be more numerically stable/easier to pattern match
// It also requires adding a broadcast op, and would probably be slower
// TODO(mbrookhart): Switch to E[(X-\mu)^2]?
std::shared_ptr<Node> variance(const std::shared_ptr<Node>& node,
std::shared_ptr<Node> variance(const Output<Node>& value,
const AxisSet& reduction_axes,
const bool bessel_correction)
{
std::shared_ptr<Node> mu = mean(node, reduction_axes);
std::shared_ptr<Node> mu = mean(value, reduction_axes);
auto reshape = node->get_shape();
auto reshape = value.get_shape();
for (auto i : reduction_axes)
{
reshape[i] = 1;
@@ -90,21 +89,21 @@ namespace ngraph
mu = std::make_shared<op::Reshape>(mu, order, reshape);
std::shared_ptr<Node> diff = make_with_numpy_broadcast<op::Subtract>(node, mu);
Output<Node> diff = make_with_numpy_broadcast<op::Subtract>(value, mu);
diff = std::make_shared<op::Sum>(diff * diff, reduction_axes);
const auto& et = node->get_element_type();
auto N = get_num_elements(node->get_shape(), reduction_axes);
const auto& et = value.get_element_type();
auto N = get_num_elements(value.get_shape(), reduction_axes);
if (bessel_correction)
{
auto N1const = op::Constant::create(et, diff->get_shape(), {N - 1});
auto N1const = op::Constant::create(et, diff.get_shape(), {N - 1});
return diff / N1const;
}
else
{
auto Nconst = op::Constant::create(et, diff->get_shape(), {N});
auto Nconst = op::Constant::create(et, diff.get_shape(), {N});
return diff / Nconst;
}
}
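As a rough usage sketch of the statistics builders above (the `ngraph/builder/reduce_ops.hpp` include path is an assumption), reducing a `{8, 16}` tensor along axis 1:

```cpp
#include <memory>

#include "ngraph/builder/reduce_ops.hpp" // assumed include path
#include "ngraph/ngraph.hpp"

using namespace ngraph;

void stats_example()
{
    std::shared_ptr<Node> x = std::make_shared<op::Parameter>(element::f32, Shape{8, 16});

    auto mu = builder::mean(x, AxisSet{1});    // per-row mean, shape {8}
    auto sd = builder::std_dev(x, AxisSet{1}); // population std-dev, divides by N
    auto var = builder::variance(x, AxisSet{1},
                                 /*bessel_correction=*/true); // divides by N - 1
}
```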
@@ -35,7 +35,7 @@ namespace ngraph
///
/// | | Type | Description |
/// | ---------------- | --------------------------------- | ----------------------------------------------------------------------------------------------------- |
/// | `node` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape |
/// | `value` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape |
/// | `reduction_axes` | AxisSet | The axes to eliminate through reduction (0 indexed). |
///
/// ## Output
@@ -43,8 +43,7 @@ namespace ngraph
/// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
std::shared_ptr<Node> l2_norm(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes);
std::shared_ptr<Node> l2_norm(const Output<Node>& value, const AxisSet& reduction_axes);
/// \brief Sum-based Mean of a Tensor.
///
@@ -66,8 +65,7 @@ namespace ngraph
/// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
std::shared_ptr<Node> mean(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes);
std::shared_ptr<Node> mean(const Output<Node>& value, const AxisSet& reduction_axes);
/// \brief Sum-based Standard Deviation of a Tensor.
///
@@ -85,7 +83,7 @@ namespace ngraph
///
/// | | Type | Description |
/// | ------------------- | --------------------------------- | ----------------------------------------------------------------------------------------------------- |
/// | `node` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape |
/// | `value` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape |
/// | `reduction_axes` | AxisSet | The axes to eliminate through reduction (0 indexed). |
/// | `bessel_correction` | bool (default = false) | Enable Bessel's correction to std_dev for small sample sizes |
///
@@ -94,7 +92,7 @@ namespace ngraph
/// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
std::shared_ptr<Node> std_dev(const std::shared_ptr<Node>& node,
std::shared_ptr<Node> std_dev(const Output<Node>& value,
const AxisSet& reduction_axes,
const bool bessel_correction = false);
@@ -114,7 +112,7 @@ namespace ngraph
///
/// | | Type | Description |
/// | ------------------- | --------------------------------- | ----------------------------------------------------------------------------------------------------- |
/// | `node` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape |
/// | `value` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape |
/// | `reduction_axes` | AxisSet | The axes to eliminate through reduction (0 indexed). |
/// | `bessel_correction` | bool (default = false) | Enable Bessel's correction to std_dev for small sample sizes |
///
@@ -123,7 +121,7 @@ namespace ngraph
/// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
std::shared_ptr<Node> variance(const std::shared_ptr<Node>& node,
std::shared_ptr<Node> variance(const Output<Node>& value,
const AxisSet& reduction_axes,
const bool bessel_correction = false);
@@ -27,14 +27,14 @@
using namespace ngraph;
using namespace std;
shared_ptr<Node> builder::reshape(const shared_ptr<Node>& node, const Shape& shape)
shared_ptr<Node> builder::reshape(const Output<Node>& value, const Shape& shape)
{
return make_shared<op::Reshape>(node, get_default_order(node->get_shape().size()), shape);
return make_shared<op::Reshape>(value, get_default_order(value.get_shape().size()), shape);
}
shared_ptr<Node> builder::reorder_axes(const shared_ptr<Node>& node, vector<size_t> axes_order = {})
shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t> axes_order)
{
Shape out_shape = node->get_shape();
Shape out_shape = value.get_shape();
if (axes_order.empty())
{
axes_order.resize(out_shape.size());
@@ -44,25 +44,25 @@ shared_ptr<Node> builder::reorder_axes(const shared_ptr<Node>& node, vector<size
{
for (size_t i = 0; i < axes_order.size(); ++i)
{
out_shape[i] = node->get_shape().at(axes_order.at(i));
out_shape[i] = value.get_shape().at(axes_order.at(i));
}
}
auto axis_vector = AxisVector{begin(axes_order), end(axes_order)};
return make_shared<op::Reshape>(node, axis_vector, out_shape);
return make_shared<op::Reshape>(value, axis_vector, out_shape);
}
shared_ptr<Node> builder::transpose(const shared_ptr<Node>& node)
shared_ptr<Node> builder::transpose(const Output<Node>& value)
{
vector<size_t> axes_order(node->get_shape().size());
vector<size_t> axes_order(value.get_shape().size());
iota(begin(axes_order), end(axes_order), 0);
reverse(begin(axes_order), end(axes_order));
return builder::reorder_axes(node, axes_order);
return builder::reorder_axes(value, axes_order);
}
shared_ptr<Node> builder::flatten(const shared_ptr<Node>& node, int axis)
shared_ptr<Node> builder::flatten(const Output<Node>& value, int axis)
{
auto data_shape = node->get_shape();
auto data_shape = value.get_shape();
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of input tensor.
// The last dimension is the product of the rest of input tensor dimensions: [d_{axis}, ..., d_n]
@@ -73,5 +73,5 @@ shared_ptr<Node> builder::flatten(const shared_ptr<Node>& node, int axis)
accumulate(next(begin(data_shape), axis), end(data_shape), 1UL, multiplies<size_t>());
return make_shared<op::Reshape>(
node, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size});
value, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size});
}
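A quick sketch of what the helpers above produce for a `{2, 3, 4}` input (include path assumed):

```cpp
#include <memory>

#include "ngraph/builder/reshape.hpp" // assumed include path
#include "ngraph/ngraph.hpp"

using namespace ngraph;

void reshape_examples()
{
    std::shared_ptr<Node> x = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});

    auto r = builder::reshape(x, Shape{6, 4});    // {2,3,4} -> {6,4}
    auto p = builder::reorder_axes(x, {2, 0, 1}); // {2,3,4} -> {4,2,3}
    auto t = builder::transpose(x);               // {2,3,4} -> {4,3,2}
    auto f = builder::flatten(x, 1);              // {2,3,4} -> {2,12}
}
```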
@@ -27,37 +27,37 @@ namespace ngraph
{
namespace builder
{
/// \brief Change shape of input tensor.
/// \brief Change the shape of a value.
///
/// \param[in] node The node producing the tensor to be reshaped.
/// \param[in] shape The new shape for input tensor.
/// \param[in] value The value to be reshaped.
/// \param[in] shape The new shape.
///
/// \return The node representing a Reshape operation.
/// \return The reshaped value.
///
std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& node, const Shape& shape);
std::shared_ptr<Node> reshape(const Output<Node>& value, const Shape& shape);
/// \brief Permute axes according to specified axes_order parameter.
///
/// \param node The node which axes we want to permute.
/// \param axes_order The permutation of node tensor axes.
/// \param value The value whose axes we want to permute.
/// \param axes_order The permutation of axes.
///
/// \return: New node with permuted axes.
std::shared_ptr<Node> reorder_axes(const std::shared_ptr<Node>& node,
std::vector<std::size_t> axes_order);
/// \return: Value with permuted axes.
std::shared_ptr<Node> reorder_axes(const Output<Node>& value,
std::vector<size_t> axes_order = {});
/// \brief Return transposed tensor (with axes in reversed order).
/// \brief Return transposed value (with axes in reversed order).
///
/// \param node Input tensor we want to transpose
/// \param value Value to transpose.
///
/// \return: New node with reversed dimensions.
std::shared_ptr<Node> transpose(const std::shared_ptr<Node>& node);
/// \return: Value with reversed dimensions.
std::shared_ptr<Node> transpose(const Output<Node>& value);
/// \brief Flatten the input tensor into a 2D matrix.
/// \brief Flatten a value into a 2D matrix.
///
/// \param node The tensor to be flattened.
/// \param value The tensor to be flattened.
/// \param axis The axis at which the shape is divided.
///
/// \return The new node will be a 2D matrix representing the flattened input node.
std::shared_ptr<Node> flatten(const std::shared_ptr<Node>& node, int axis);
/// \return A 2D matrix representing the flattened input value.
std::shared_ptr<Node> flatten(const Output<Node>& value, int axis);
} // namespace builder
} // namespace ngraph
@@ -21,31 +21,31 @@ using namespace ngraph;
namespace
{
inline std::size_t get_valid_array_index(std::size_t idx, std::size_t axis_size)
inline size_t get_valid_array_index(size_t idx, size_t axis_size)
{
return std::min(idx, axis_size);
}
std::shared_ptr<op::Slice> make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& axes,
const std::vector<std::size_t>& starts,
const std::vector<std::size_t>& ends)
std::shared_ptr<op::Slice> make_ng_slice(const Output<Node>& output,
const std::vector<size_t>& axes,
const std::vector<size_t>& starts,
const std::vector<size_t>& ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
std::vector<size_t> upper_bounds{output.get_shape()};
std::vector<size_t> lower_bounds(upper_bounds.size());
for (size_t index{0}; index < axes.size(); ++index)
{
std::size_t axis{axes.at(index)};
size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
get_valid_array_index(starts.at(index), output.get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
get_valid_array_index(ends.at(index), output.get_shape().at(axis));
}
return std::make_shared<op::Slice>(node, lower_bounds, upper_bounds);
return std::make_shared<op::Slice>(output, lower_bounds, upper_bounds);
}
}
NodeVector builder::split(const std::shared_ptr<ngraph::Node>& node,
NodeVector builder::split(const Output<ngraph::Node>& value,
const std::vector<size_t>& length_parts,
size_t axis)
{
@@ -54,21 +54,21 @@ NodeVector builder::split(const std::shared_ptr<ngraph::Node>& node,
for (const auto& length_part : length_parts)
{
size_t end_index{start_index + length_part};
outputs.push_back(make_ng_slice(node, {axis}, {start_index}, {end_index}));
outputs.push_back(make_ng_slice(value, {axis}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
}
NodeVector builder::split(const std::shared_ptr<ngraph::Node>& node, size_t split_parts, int axis)
NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axis)
{
size_t axis_to_split{static_cast<size_t>(axis)};
if (axis < 0)
{
axis_to_split = node->get_shape().size() + axis;
axis_to_split = value.get_shape().size() + axis;
}
size_t length_axis_to_split{node->get_shape().at(axis_to_split)};
size_t length_axis_to_split{value.get_shape().at(axis_to_split)};
std::vector<size_t> length_parts(split_parts, length_axis_to_split / split_parts);
return split(node, length_parts, axis_to_split);
return split(value, length_parts, axis_to_split);
}
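A short sketch of the two overloads above (include path assumed). Note that `make_ng_slice` clamps slice bounds to the axis length via `get_valid_array_index`, so overlong indices are truncated rather than rejected:

```cpp
#include <memory>
#include <vector>

#include "ngraph/builder/split.hpp" // assumed include path
#include "ngraph/ngraph.hpp"

using namespace ngraph;

void split_examples()
{
    std::shared_ptr<Node> x = std::make_shared<op::Parameter>(element::f32, Shape{6, 4});

    // Explicit part lengths along axis 0: shapes {1,4}, {2,4}, {3,4}.
    NodeVector parts = builder::split(x, std::vector<size_t>{1, 2, 3}, 0);

    // Two equal parts along the last axis (a negative axis counts from
    // the end): shapes {6,2} and {6,2}.
    NodeVector halves = builder::split(x, 2, -1);
}
```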
@@ -23,22 +23,22 @@ namespace ngraph
{
namespace builder
{
/// \brief Split node on specified axis into multiple parts.
/// \brief Split a value on a specified axis into multiple parts.
///
/// \param[in] node The input node.
/// \param[in] value The value to be split.
/// \param[in] length_parts The vector defining the lengths of each split part.
/// \param[in] axis The axis on which to split the value. Defaults to axis 0.
///
/// \return The vector containing the nodes the value is split into.
///
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& length_parts,
std::size_t axis = 0);
NodeVector split(const Output<Node>& value,
const std::vector<size_t>& length_parts,
size_t axis = 0);
/// \brief Split a value on a specified axis into multiple parts.
///
/// \param[in] node The input node.
/// \param[in] split_parts The number of parts we want to split input node at given
/// \param[in] value The value to split.
/// \param[in] split_parts The number of parts we want to split the value into at the given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param[in] axis The axis on which to split the value. Defaults to axis 0.
@@ -49,7 +49,6 @@ namespace ngraph
///
/// \return The vector containing the nodes the value is split into.
///
NodeVector
split(const std::shared_ptr<ngraph::Node>& node, std::size_t split_parts, int axis = 0);
NodeVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
} // namespace builder
} // namespace ngraph
@@ -84,8 +84,7 @@ void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto order = ngraph::get_default_order(zsum->get_shape());
auto zreshape = make_shared<op::Reshape>(zsum, order, shape);
auto adjoint =
z - builder::make_with_numpy_broadcast<op::Multiply>(shared_from_this(), zreshape);
auto adjoint = z - builder::make_with_numpy_broadcast<op::Multiply>(output(0), zreshape);
auto x = get_argument(0);
adjoints.add_delta(x, adjoint);
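Reconstructing the math from the visible adjoint code, and assuming (consistent with the surrounding function) that `z` is the elementwise product of the incoming delta with the softmax output y, the delta added for x is the standard softmax vector-Jacobian product:

\[
\bar{x} = z - y \odot \sum_{a \in A} z, \qquad z = \bar{y} \odot y,
\]

where A is the set of softmax axes and the reduced sum is broadcast back to the input shape by the `Reshape` / numpy-broadcast pair above.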