Commit 640295cf authored by baojun's avatar baojun Committed by Scott Cyphers

Add PDPD style autobroadcast (#3645)

* add pdpd autob spec and ut

* add pdpd autob to pass check

* add dummy dynamic ut

* add dynamic test case

* use int64_t

* pass dynamic ut

* add pdpd style kernel

* handle pdpd validation

* add dummy pdpd style bcast

* implement pdpd style bcast

* add validation

* use output arg

* use ngraph_check
Co-Authored-By: 's avatarAdam Procter <adam.m.procter@intel.com>

* enable error print

* accuracy mismatch in plaidml

* fix ut on windows

* make separated tests

* fix trailing one case

* fix warning
parent bad0c25b
......@@ -362,6 +362,8 @@ set (SRC
op/fused/squeeze.hpp
op/fused/unsqueeze.cpp
op/fused/unsqueeze.hpp
op/util/attr_types.cpp
op/util/attr_types.hpp
op/util/activation_functions.cpp
op/util/activation_functions.hpp
op/util/arithmetic_reduction.cpp
......
......@@ -724,7 +724,8 @@ std::tuple<element::Type, PartialShape>
PartialShape::merge_into(pshape, get_input_partial_shape(i)),
"Argument shapes are inconsistent.");
}
else if (autob.m_type == op::AutoBroadcastType::NUMPY)
else if (autob.m_type == op::AutoBroadcastType::NUMPY ||
autob.m_type == op::AutoBroadcastType::PDPD)
{
NODE_VALIDATION_CHECK(
this,
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/util/attr_types.hpp"
using namespace ngraph;
std::ostream& op::operator<<(std::ostream& s, const op::AutoBroadcastType& type)
{
switch (type)
{
case op::AutoBroadcastType::NONE: s << "NONE"; break;
case op::AutoBroadcastType::NUMPY: s << "NUMPY"; break;
case op::AutoBroadcastType::PDPD: s << "PDPD"; break;
default: s << "Undefined Type";
}
return s;
}
......@@ -17,6 +17,7 @@
#pragma once
#include <cstddef>
#include <ostream>
namespace ngraph
{
......@@ -73,12 +74,29 @@ namespace ngraph
/// A: Shape(2, 1, 6)
/// B: Shape( 3, 1)
/// Result: Shape(2, 3, 6)
/// PDPD - PaddlePaddle-style implicit broadcasting
/// (https://github.com/PaddlePaddle/Paddle/blob/release/1.5/paddle/
/// fluid/operators/elementwise/elementwise_op.h#L126)
/// Broadcast B to match the shape of A, where axis is the start
/// dimension index to align B with A. If axis is -1 (default), i
/// axis = rank(A) - rank(B). The trailing dimensions of size 1 for B
/// will be ignored.
///
/// E.g.,
/// A: Shape(2, 3, 4, 5)
/// B: Shape( 3, 4 ) with axis =1
/// Result: Shape(2, 3, 4, 5)
///
/// A: Shape(2, 3, 4, 5)
/// B: Shape( 3, 1 ) with axis = 1
/// Result: Shape(2, 3, 4, 5)
///
/// TODO: Add more implicit broadcast modes used by frameworks
enum class AutoBroadcastType
{
NONE = 0,
NUMPY
NUMPY,
PDPD
};
/// \brief Specifies how eps is combined with L2 value
......@@ -103,14 +121,16 @@ namespace ngraph
, m_axis(0)
{
}
AutoBroadcastSpec(AutoBroadcastType type, size_t axis)
AutoBroadcastSpec(AutoBroadcastType type, int64_t axis)
: m_type(type)
, m_axis(axis)
{
}
AutoBroadcastType m_type; // Implicit broadcasting algorithm
size_t m_axis; // Axis to start alignment on
int64_t m_axis; // Axis to start alignment on
};
std::ostream& operator<<(std::ostream& s, const AutoBroadcastType& type);
}
}
......@@ -172,6 +172,59 @@ static std::shared_ptr<ngraph::Node>
return std::make_shared<ngraph::op::Broadcast>(broadcasted_value, output_shape, broadcast_axes);
}
/// \brief Broadcast input node.
///
/// \param[in] value The input Node to be broadcast.
/// \param[in] output_shape The output shape.
/// \param[in] axis The start index to align with output_shape
///
/// \return The broadcasted Node.
///
static std::shared_ptr<ngraph::Node> broadcast_value_pdpd_style(
const ngraph::Output<ngraph::Node>& value, const ngraph::Shape& output_shape, int64_t axis)
{
auto value_shape = value.get_shape();
// If node already has the required shape, return original node
if (output_shape == value_shape)
{
return value.as_single_output_node();
}
if (axis == -1)
{
axis = output_shape.size() - value_shape.size();
}
auto trimmed_value_shape = value_shape;
while (trimmed_value_shape.size() > 0 && trimmed_value_shape.back() == 1)
{
trimmed_value_shape.pop_back();
}
ngraph::AxisSet axes;
for (int64_t i = 0; i < axis; ++i)
{
axes.insert(static_cast<size_t>(i));
}
for (size_t i = axis + trimmed_value_shape.size(); i < output_shape.size(); ++i)
{
axes.insert(i);
}
auto trimmed_value = value;
if (value_shape != trimmed_value_shape)
{
trimmed_value = std::make_shared<ngraph::op::Reshape>(
value, ngraph::get_default_order(value_shape), trimmed_value_shape);
}
auto value_bcast = std::make_shared<ngraph::op::Broadcast>(trimmed_value, output_shape, axes);
return value_bcast;
}
namespace ngraph
{
namespace op
......@@ -415,6 +468,22 @@ namespace ngraph
return {left, broadcast_right};
}
NodeVector pdpd_style_broadcast(const NodeVector& inputs, int64_t axis)
{
if (inputs.size() <= 1)
{
return inputs;
}
NodeVector broadcasted_inputs{inputs[0]};
for (std::size_t i = 1; i < inputs.size(); ++i)
{
broadcasted_inputs.push_back(
broadcast_value_pdpd_style(inputs[i], inputs[0]->get_shape(), axis));
}
return broadcasted_inputs;
}
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis)
......
......@@ -133,6 +133,15 @@ namespace ngraph
OutputVector numpy_style_broadcast_values_for_matmul_operation(const Output<Node>& left,
const Output<Node>& right);
/// \brief Cast shape of all input nodes for an element-wise operation that requires
/// shape-compatibility
///
/// \param inputs Original list of inputs
/// \param axis Index starting to align
///
/// \return pdpd-style broadcasted list of nodes.
NodeVector pdpd_style_broadcast(const NodeVector& inputs, int64_t axis);
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
......
......@@ -250,30 +250,81 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
const PartialShape& src,
const op::AutoBroadcastSpec& autob)
{
NGRAPH_CHECK(autob.m_type == op::AutoBroadcastType::NUMPY, "Unsupported auto broadcast type");
if (dst.rank().is_dynamic() || src.rank().is_dynamic())
switch (autob.m_type)
{
dst = PartialShape::dynamic();
return true;
case op::AutoBroadcastType::NONE: return true;
case op::AutoBroadcastType::NUMPY:
{
if (dst.rank().is_dynamic() || src.rank().is_dynamic())
{
dst = PartialShape::dynamic();
return true;
}
else
{
// Ranks are both static.
auto dst_rank = size_t(dst.rank());
auto src_rank = size_t(src.rank());
auto new_rank = std::max(dst_rank, src_rank);
std::vector<Dimension> dims(new_rank);
bool success = true;
for (size_t i = 0; i < new_rank; i++)
{
auto dsti =
i < (new_rank - dst_rank) ? Dimension(1) : dst[i - (new_rank - dst_rank)];
auto srci =
i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)];
success &= Dimension::broadcast_merge(dims[i], dsti, srci);
}
dst = PartialShape(dims);
return success;
}
}
else
case op::AutoBroadcastType::PDPD:
{
// Ranks are both static.
auto dst_rank = size_t(dst.rank());
auto src_rank = size_t(src.rank());
auto new_rank = std::max(dst_rank, src_rank);
std::vector<Dimension> dims(new_rank);
bool success = true;
for (size_t i = 0; i < new_rank; i++)
if (dst.rank().is_dynamic() || src.rank().is_dynamic())
{
auto dsti = i < (new_rank - dst_rank) ? Dimension(1) : dst[i - (new_rank - dst_rank)];
auto srci = i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)];
success &= Dimension::broadcast_merge(dims[i], dsti, srci);
return true;
}
dst = PartialShape(dims);
return success;
else
{
// Ranks are both static.
auto dst_rank = size_t(dst.rank());
auto src_rank = size_t(src.rank());
if (dst_rank == src_rank && dst.compatible(src))
return true;
int64_t axis = autob.m_axis;
if (axis < -1)
{
return false;
}
if (axis == -1)
{
axis = dst_rank - src_rank;
}
size_t len = src_rank;
while (len > 0 && src[len - 1].is_static() && size_t(src[len - 1]) == 1)
{
--len;
}
for (size_t i = axis; i < axis + len; ++i)
{
if (!(dst[i].compatible(src[i - axis])))
{
return false;
}
}
return true;
}
}
default: NGRAPH_CHECK(false, "Unsupported auto broadcast type: ", autob.m_type);
}
return false;
}
bool PartialShape::all_non_negative() const
......
......@@ -28,14 +28,19 @@ namespace ngraph
NodeVector rc;
if (node->supports_auto_broadcast())
{
if (node->get_autob().m_type == op::AutoBroadcastType::NONE)
auto autob = node->get_autob();
if (autob.m_type == op::AutoBroadcastType::NONE)
{
rc = node->get_arguments();
}
else if (node->get_autob().m_type == op::AutoBroadcastType::NUMPY)
else if (autob.m_type == op::AutoBroadcastType::NUMPY)
{
rc = op::numpy_style_broadcast(node->get_arguments());
}
else if (autob.m_type == op::AutoBroadcastType::PDPD)
{
rc = op::pdpd_style_broadcast(node->get_arguments(), autob.m_axis);
}
else
{
throw ngraph_error("Unsupported implicit broadcast type");
......
......@@ -95,6 +95,8 @@ model_reverse_sequence_1_batch_0
model_dequantize_linear_scalar_zero_scale_int8
model_softmax
avg_pool_3d_uneven_strided_padded
auto_bcast_binary_elementwise_pdpd
auto_bcast_binary_elementwise_pdpd_dynamic
rnn_cell_activation_function
gru_cell_bias_clip
gru_cell_linear_before_reset
......
......@@ -84,60 +84,140 @@ namespace ngraph
// Output shape
// ------------
// [ 3, 2, 6]
Shape arg0_padded_shape = arg0_shape;
Shape arg1_padded_shape = arg1_shape;
while (arg0_padded_shape.size() < arg1_padded_shape.size())
{
arg0_padded_shape.insert(arg0_padded_shape.begin(), 1);
}
Shape arg0_padded_shape = arg0_shape;
Shape arg1_padded_shape = arg1_shape;
while (arg1_padded_shape.size() < arg0_padded_shape.size())
{
arg1_padded_shape.insert(arg1_padded_shape.begin(), 1);
}
while (arg0_padded_shape.size() < arg1_padded_shape.size())
{
arg0_padded_shape.insert(arg0_padded_shape.begin(), 1);
}
while (arg1_padded_shape.size() < arg0_padded_shape.size())
{
arg1_padded_shape.insert(arg1_padded_shape.begin(), 1);
}
Shape arg0_squeezed_shape;
Shape arg1_squeezed_shape;
AxisSet arg0_squeezed_axes;
AxisSet arg1_squeezed_axes;
Shape output_shape;
Shape arg0_squeezed_shape;
Shape arg1_squeezed_shape;
AxisSet arg0_squeezed_axes;
AxisSet arg1_squeezed_axes;
Shape output_shape;
for (size_t i = 0; i < arg0_padded_shape.size(); i++)
for (size_t i = 0; i < arg0_padded_shape.size(); i++)
{
if (arg0_padded_shape[i] == 1)
{
arg0_squeezed_axes.insert(i);
}
else
{
arg0_squeezed_shape.push_back(arg0_padded_shape[i]);
}
if (arg1_padded_shape[i] == 1)
{
arg1_squeezed_axes.insert(i);
}
else
{
arg1_squeezed_shape.push_back(arg1_padded_shape[i]);
}
output_shape.push_back(arg0_padded_shape[i] == 1
? arg1_padded_shape[i]
: arg0_padded_shape[i]);
}
CoordinateTransform arg0_transform(arg0_squeezed_shape);
CoordinateTransform arg1_transform(arg1_squeezed_shape);
CoordinateTransform output_transform(output_shape);
for (const Coordinate& output_coord : output_transform)
{
Coordinate arg0_coord = reduce(output_coord, arg0_squeezed_axes);
Coordinate arg1_coord = reduce(output_coord, arg1_squeezed_axes);
out[output_transform.index(output_coord)] =
elementwise_functor(arg0[arg0_transform.index(arg0_coord)],
arg1[arg1_transform.index(arg1_coord)]);
}
}
break;
case op::AutoBroadcastType::PDPD:
// We'll be using CoordinateTransform to handle the broadcasting. No need to
// process arg0 and output shape will be the same as arg0. We need to process
// arg1 and the general procedure is as follows:
//
// (1) Trim trailing ones from arg1 shape.
// (2) Left and right pad arg1 to match arg0 shape. Axis is the index start
// to align between arg0 and arg1.
// (3) Squeeze (remove ones from) arg1 shape, and record the squeezed axis
// indices.
// (3) Using CoordinateTransform, broadcast arg1 to the final output
// shape. The "broadcasted axes" will be those that were squeezed in step
// 23.
//
// Example:
//
// Input shape-> Padded shape-> Squeezed Shape/Squeezed Axes
// ----------- ------------ ----------------------------
// a: [ 3, 4, 5, 6] [ 3, 4, 5, 6] [ 3, 4, 5, 6]
// b: [ 4, 5, ] [ 1, 4, 5, 1] [ 4, 5 ] {0,3}
// | | |
// v v v
// Output shape
// ------------
// [ 3, 4, 5, 6]
{
if (arg0_padded_shape[i] == 1)
int64_t axis = broadcast_spec.m_axis;
if (axis == -1)
{
arg0_squeezed_axes.insert(i);
axis = arg0_shape.size() - arg1_shape.size();
}
else
Shape arg1_padded_shape = arg1_shape;
// Trim trailing ones
while (arg1_padded_shape.size() > 0 && arg1_padded_shape.back() == 1)
{
arg0_squeezed_shape.push_back(arg0_padded_shape[i]);
arg1_padded_shape.pop_back();
}
if (arg1_padded_shape[i] == 1)
for (int64_t i = 0; i < axis; ++i)
{
arg1_squeezed_axes.insert(i);
arg1_padded_shape.insert(arg1_padded_shape.begin(), 1);
}
else
while (arg1_padded_shape.size() < arg0_shape.size())
{
arg1_squeezed_shape.push_back(arg1_padded_shape[i]);
arg1_padded_shape.insert(arg1_padded_shape.end(), 1);
}
output_shape.push_back(arg0_padded_shape[i] == 1 ? arg1_padded_shape[i]
: arg0_padded_shape[i]);
}
Shape arg1_squeezed_shape;
AxisSet arg1_squeezed_axes;
for (size_t i = 0; i < arg0_shape.size(); i++)
{
if (arg1_padded_shape[i] == 1)
{
arg1_squeezed_axes.insert(i);
}
else
{
arg1_squeezed_shape.push_back(arg1_padded_shape[i]);
}
}
CoordinateTransform arg0_transform(arg0_squeezed_shape);
CoordinateTransform arg1_transform(arg1_squeezed_shape);
CoordinateTransform output_transform(output_shape);
CoordinateTransform arg0_transform(arg0_shape);
CoordinateTransform arg1_transform(arg1_squeezed_shape);
CoordinateTransform output_transform(arg0_shape);
for (const Coordinate& output_coord : output_transform)
{
Coordinate arg0_coord = reduce(output_coord, arg0_squeezed_axes);
Coordinate arg1_coord = reduce(output_coord, arg1_squeezed_axes);
out[output_transform.index(output_coord)] =
elementwise_functor(arg0[arg0_transform.index(arg0_coord)],
arg1[arg1_transform.index(arg1_coord)]);
for (const Coordinate& output_coord : output_transform)
{
Coordinate arg1_coord = reduce(output_coord, arg1_squeezed_axes);
out[output_transform.index(output_coord)] =
elementwise_functor(arg0[arg0_transform.index(output_coord)],
arg1[arg1_transform.index(arg1_coord)]);
}
}
}
}
......
......@@ -340,7 +340,7 @@ static op::AutoBroadcastSpec read_auto_broadcast(json js_node, const std::string
{
json j = js_node[attr];
return op::AutoBroadcastSpec(static_cast<op::AutoBroadcastType>(j.at("type")),
j.at("axis").get<size_t>());
j.at("axis").get<int64_t>());
}
else
{
......
......@@ -45,8 +45,10 @@ using namespace ngraph;
static string s_manifest = "${MANIFEST}";
template <typename optype, typename itype, typename otype>
void check_auto_bcast(const std::vector<std::vector<itype>>& inputs,
const std::vector<otype> output)
void check_auto_bcast(
const std::vector<std::vector<itype>>& inputs,
const std::vector<otype> output,
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY))
{
auto iet = element::from<itype>();
auto oet = element::from<otype>();
......@@ -61,8 +63,7 @@ void check_auto_bcast(const std::vector<std::vector<itype>>& inputs,
}
auto A = make_shared<op::Parameter>(iet, Shape{2, 3});
auto B = make_shared<op::Parameter>(iet, Shape{3});
auto f = make_shared<Function>(make_shared<optype>(A, B, op::AutoBroadcastType::NUMPY),
ParameterVector{A, B});
auto f = make_shared<Function>(make_shared<optype>(A, B, autob), ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......@@ -108,3 +109,85 @@ NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_binary_elementwise)
check_auto_bcast<op::NotEqual, uint8_t, char>({{1, 2, 3, 4, 5, 6}, {1, 5, 8}},
{0, 1, 1, 1, 0, 1});
}
NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_binary_elementwise_pdpd)
{
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, 1);
check_auto_bcast<op::Add, float, float>(
{{1, 2, 3, 4, 5, 6}, {5, 6, 7}}, {6, 8, 10, 9, 11, 13}, autob);
check_auto_bcast<op::Subtract, float, float>(
{{1, 2, 3, 4, 5, 6}, {5, 6, 7}}, {-4.f, -4.f, -4.f, -1.f, -1.f, -1.f}, autob);
check_auto_bcast<op::Multiply, float, float>(
{{1, 2, 3, 4, 5, 6}, {5, 6, 7}}, {5, 12, 21, 20, 30, 42}, autob);
check_auto_bcast<op::Divide, float, float>(
{{4, 5, 6, 7, 8, 9}, {1, 2, 3}}, {4, 2.5f, 2, 7, 4, 3}, autob);
check_auto_bcast<op::Maximum, float, float>(
{{1, 2, 3, 4, 5, 6}, {1, 5, 8}}, {1, 5, 8, 4, 5, 8}, autob);
check_auto_bcast<op::Minimum, float, float>(
{{1, 2, 3, 4, 5, 6}, {1, 5, 8}}, {1, 2, 3, 1, 5, 6}, autob);
check_auto_bcast<op::Power, float, float>(
{{1, 2, 3, 4, 5, 6}, {1, 2, 3}}, {1, 4, 27, 4, 25, 216}, autob);
check_auto_bcast<op::And, char, char>(
{{1, 0, 1, 0, 0, 1}, {1, 0, 1}}, {1, 0, 1, 0, 0, 1}, autob);
check_auto_bcast<op::Or, char, char>(
{{1, 0, 1, 0, 1, 1}, {1, 0, 0}}, {1, 0, 1, 1, 1, 1}, autob);
check_auto_bcast<op::Equal, uint8_t, char>(
{{1, 0, 1, 0, 1, 1}, {1, 0, 0}}, {1, 1, 0, 0, 0, 0}, autob);
check_auto_bcast<op::Greater, float, char>(
{{1, 2, 3, 4, 5, 6}, {1, 5, 8}}, {0, 0, 0, 1, 0, 0}, autob);
check_auto_bcast<op::GreaterEq, float, char>(
{{1, 2, 3, 4, 5, 6}, {1, 5, 8}}, {1, 0, 0, 1, 1, 0}, autob);
check_auto_bcast<op::Less, uint8_t, char>(
{{1, 2, 3, 4, 5, 6}, {1, 5, 8}}, {0, 1, 1, 0, 0, 1}, autob);
check_auto_bcast<op::LessEq, uint8_t, char>(
{{1, 2, 3, 4, 5, 6}, {1, 5, 8}}, {1, 1, 1, 0, 1, 1}, autob);
check_auto_bcast<op::NotEqual, uint8_t, char>(
{{1, 2, 3, 4, 5, 6}, {1, 5, 8}}, {0, 1, 1, 1, 0, 1}, autob);
}
NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_binary_elementwise_pdpd_dynamic)
{
auto pshape_a = PartialShape::dynamic();
auto pshape_b = PartialShape::dynamic();
auto a = make_shared<op::Parameter>(element::f32, pshape_a);
auto b = make_shared<op::Parameter>(element::f32, pshape_b);
op::AutoBroadcastSpec autob = op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, -1);
auto f = make_shared<Function>(make_shared<op::Add>(a, b, autob), ParameterVector{a, b});
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto ex = backend->compile(f);
auto t_r = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
auto t_a = backend->create_tensor(element::f32, Shape{2, 3});
auto t_b = backend->create_tensor(element::f32, Shape{3});
copy_data(t_a, vector<float>{1, 2, 3, 4, 5, 6});
copy_data(t_b, vector<float>{5, 6, 7});
ex->call_with_validate({t_r}, {t_a, t_b});
ASSERT_EQ(t_r->get_shape(), (Shape{2, 3}));
auto results = read_vector<float>(t_r);
vector<float> expected_values{6, 8, 10, 9, 11, 13};
EXPECT_TRUE(test::all_close_f(results, expected_values));
// a shape {2, 3, 4, 5}, b shape {3, 4} axis = 1
autob = op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, 1);
f = make_shared<Function>(make_shared<op::Add>(a, b, autob), ParameterVector{a, b});
ex = backend->compile(f);
t_r = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
t_a = backend->create_tensor(element::f32, Shape{2, 3, 4, 5});
t_b = backend->create_tensor(element::f32, Shape{3, 4});
copy_data(t_a, vector<float>(2 * 3 * 4 * 5, 1));
copy_data(t_b, vector<float>(3 * 4, 1));
ex->call_with_validate({t_r}, {t_a, t_b});
ASSERT_EQ(t_r->get_shape(), (Shape{2, 3, 4, 5}));
// a shape {2, 3, 4, 5}, b shape {3, 1} axis = 1
t_r = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
t_a = backend->create_tensor(element::f32, Shape{2, 3, 4, 5});
t_b = backend->create_tensor(element::f32, Shape{3, 1});
copy_data(t_a, vector<float>(2 * 3 * 4 * 5, 1));
copy_data(t_b, vector<float>(3, 1));
ex->call_with_validate({t_r}, {t_a, t_b});
ASSERT_EQ(t_r->get_shape(), (Shape{2, 3, 4, 5}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment