Commit d19d1271 authored by Amy Zhuang, committed by Michał Karzyński

Modify slice layout. (#1788)

parent b4338a52
...@@ -89,6 +89,7 @@ namespace ngraph ...@@ -89,6 +89,7 @@ namespace ngraph
protected: protected:
/// Throws if the node is invalid. /// Throws if the node is invalid.
virtual void validate_and_infer_types();
// Called in constructors during transition // Called in constructors during transition
void constructor_validate_and_infer_types(); void constructor_validate_and_infer_types();
...@@ -106,7 +107,7 @@ namespace ngraph ...@@ -106,7 +107,7 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {} virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
public: public:
virtual void validate_and_infer_types(); void revalidate_and_infer_types() { validate_and_infer_types(); }
// Called after transition // Called after transition
void delayed_validate_and_infer_types(); void delayed_validate_and_infer_types();
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "ngraph/runtime/cpu/kernel/softmax.hpp" #include "ngraph/runtime/cpu/kernel/softmax.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp" #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -131,8 +132,35 @@ namespace ngraph ...@@ -131,8 +132,35 @@ namespace ngraph
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
else if (arg_shape.size() == 4 && axes.size() == 3)
{
std::function<decltype(runtime::cpu::kernel::softmax_4d_3rd<float>)> kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::softmax_4d_3rd);
auto functor = [&, kernel, arg_shape, axes](CPURuntimeContext* ctx) {
kernel(arg_tensor, out_tensor, arg_shape, axes);
};
functors.emplace_back(functor);
}
else if (softmax->get_element_type() == element::f32)
{
NGRAPH_WARN << "Falling back to refernce kernel for softmax " << arg_shape
<< " over " << axes;
auto functor = [&, arg_shape, axes](CPURuntimeContext* ctx) {
runtime::reference::softmax<float>(static_cast<float*>(arg_tensor),
static_cast<float*>(out_tensor),
arg_shape,
axes);
};
functors.emplace_back(functor);
}
else else
{ {
NGRAPH_ERR << "Unsupported Softmax " << arg_shape << " over " << axes
<< " in cpu buiilder";
throw ngraph_error("Unsupported Softmax"); throw ngraph_error("Unsupported Softmax");
} }
} }
......
...@@ -126,6 +126,7 @@ ...@@ -126,6 +126,7 @@
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/nop_elimination.hpp" #include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/aligned_buffer.hpp" #include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp" #include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
...@@ -1001,6 +1002,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma ...@@ -1001,6 +1002,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
{ {
pass_manager.register_pass<ngraph::pass::LikeReplacement>(); pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>(); pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing // TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
// failing mxnet unit tests. // failing mxnet unit tests.
// pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>(); // pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
...@@ -1013,7 +1015,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma ...@@ -1013,7 +1015,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>(); pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>(); pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>(); // pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>(); pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
NodeVector nv_cwi; // We dont need CPUWorkspaceInsertion to return list of indices NodeVector nv_cwi; // We dont need CPUWorkspaceInsertion to return list of indices
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi, false); pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi, false);
......
...@@ -147,6 +147,15 @@ namespace ngraph ...@@ -147,6 +147,15 @@ namespace ngraph
{ {
softmax<ElementType, 3, 2>(input, output, input_shape, softmax_axes); softmax<ElementType, 3, 2>(input, output, input_shape, softmax_axes);
} }
// Softmax kernel entry point for rank-4 inputs reduced over three axes.
// Thin forwarder that instantiates the generic softmax kernel with
// Rank = 4 and AxisCount = 3 (mirrors the rank-3 wrapper above).
template <typename ElementType>
void softmax_4d_3rd(void* input_buffer,
                    void* output_buffer,
                    const Shape& shape,
                    const AxisSet& axes)
{
    softmax<ElementType, 4, 3>(input_buffer, output_buffer, shape, axes);
}
} }
} }
} }
......
...@@ -529,7 +529,7 @@ memory::desc runtime::cpu::mkldnn_utils::expand_blocked_md(const memory::desc& i ...@@ -529,7 +529,7 @@ memory::desc runtime::cpu::mkldnn_utils::expand_blocked_md(const memory::desc& i
size_t k = 0; size_t k = 0;
for (size_t i = 0, j = 0; j < md.ndims; j++) for (size_t i = 0, j = 0; j < md.ndims; j++)
{ {
if (j == axis_list[k]) if (k < axis_list.size() && j == axis_list[k])
{ {
k++; k++;
md.dims[j] = 1; md.dims[j] = 1;
...@@ -545,7 +545,8 @@ memory::desc runtime::cpu::mkldnn_utils::expand_blocked_md(const memory::desc& i ...@@ -545,7 +545,8 @@ memory::desc runtime::cpu::mkldnn_utils::expand_blocked_md(const memory::desc& i
} }
else else
{ {
md.layout_desc.blocking.strides[1][j] = 0; md.layout_desc.blocking.strides[1][j] =
in.data.layout_desc.blocking.strides[0][in.data.ndims - 1];
size_t nelems = 1; size_t nelems = 1;
for (size_t idx = 0; idx < in.data.ndims; idx++) for (size_t idx = 0; idx < in.data.ndims; idx++)
nelems *= in.data.dims[idx]; nelems *= in.data.dims[idx];
......
...@@ -124,7 +124,7 @@ void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion ...@@ -124,7 +124,7 @@ void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion
NGRAPH_DEBUG << "conv_horizontal_fusion: slice shape " << slice_shape << "\n"; NGRAPH_DEBUG << "conv_horizontal_fusion: slice shape " << slice_shape << "\n";
auto lower_bounds = Coordinate{0, index, 0, 0}; auto lower_bounds = Coordinate{0, index, 0, 0};
index += slice_shape[1]; index += slice_shape[1];
auto upper_bounds = Coordinate{slice_shape[0], index, slice_shape[2], slice_shape[2]}; auto upper_bounds = Coordinate{slice_shape[0], index, slice_shape[2], slice_shape[3]};
NGRAPH_DEBUG << "conv_horizontal_fusion: lower_bounds " << lower_bounds << "\n"; NGRAPH_DEBUG << "conv_horizontal_fusion: lower_bounds " << lower_bounds << "\n";
NGRAPH_DEBUG << "conv_horizontal_fusion: upper_bounds " << upper_bounds << "\n"; NGRAPH_DEBUG << "conv_horizontal_fusion: upper_bounds " << upper_bounds << "\n";
auto slice = auto slice =
......
...@@ -1533,7 +1533,18 @@ namespace ngraph ...@@ -1533,7 +1533,18 @@ namespace ngraph
} }
else else
{ {
set_native_layouts(external_function, node); if (mkldnn_utils::get_input_mkldnn_md(node.get(), 0).data.format ==
mkldnn_format_undef)
{
set_native_layouts(external_function, node);
}
else
{
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
vector<memory::desc> o_mds;
o_mds.push_back(input_md);
set_output_layouts(node, o_mds);
}
} }
} }
...@@ -1775,33 +1786,36 @@ namespace ngraph ...@@ -1775,33 +1786,36 @@ namespace ngraph
auto result_shape = slice->get_output_shape(0); auto result_shape = slice->get_output_shape(0);
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0); auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
auto input_pd = mkldnn::memory::primitive_desc( NGRAPH_DEBUG << "input memory format: " << input_md.data.format << "\n";
input_md, runtime::cpu::mkldnn_utils::global_cpu_engine); auto result_format =
auto dims = mkldnn::memory::dims(result_shape.begin(), result_shape.end()); static_cast<mkldnn::memory::format>(input_md.data.format);
auto offsets =
mkldnn::memory::dims(lower_bounds.begin(), lower_bounds.end());
try // check lower bounds and output shape
{ for (auto i = 0; i < input_md.data.ndims; i++)
// MKLDNN currently doesn't support views for blocked layouts
// when the dims and offsets are not divisible by the block size
auto view_md = mkldnn::view::primitive_desc(input_pd, dims, offsets)
.dst_primitive_desc()
.desc();
vector<memory::desc> o_mds;
o_mds.push_back(view_md);
set_output_layouts(node, o_mds);
}
catch (const mkldnn::error& e)
{ {
if (e.status == mkldnn_unimplemented) auto block_size = input_md.data.layout_desc.blocking.block_dims[i];
if (block_size != 0 && (lower_bounds[i] % block_size != 0 ||
result_shape[i] % block_size != 0))
{ {
NGRAPH_DEBUG << "slice: number of channels in lower bounds or "
"output shape is not multiple of block size, "
"set native layout\n";
set_native_layouts(external_function, node); set_native_layouts(external_function, node);
return;
} }
else }
{
throw ngraph_error(e.message); if (result_format == mkldnn::memory::blocked)
} {
set_native_layouts(external_function, node);
}
else
{
vector<memory::desc> o_mds;
auto result_desc = mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, true, result_format);
o_mds.push_back(result_desc);
set_output_layouts(node, o_mds);
} }
} }
else else
......
...@@ -347,7 +347,7 @@ bool ngraph::runtime::cpu::pass::CPUReshapeSinking::run_on_function( ...@@ -347,7 +347,7 @@ bool ngraph::runtime::cpu::pass::CPUReshapeSinking::run_on_function(
//fix wrong shape info wholesale //fix wrong shape info wholesale
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
{ {
n->validate_and_infer_types(); n->revalidate_and_infer_types();
} }
return true; return true;
} }
...@@ -890,6 +890,7 @@ TEST(cpu_fusion, conv_bias_relu_n2c1h2w2_2) ...@@ -890,6 +890,7 @@ TEST(cpu_fusion, conv_bias_relu_n2c1h2w2_2)
EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0))); EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
} }
#if 0
TEST(cpu_fusion, conv_horizontal_fusion) TEST(cpu_fusion, conv_horizontal_fusion)
{ {
Shape shape_a{2, 1, 6, 6}; Shape shape_a{2, 1, 6, 6};
...@@ -940,6 +941,7 @@ TEST(cpu_fusion, conv_horizontal_fusion) ...@@ -940,6 +941,7 @@ TEST(cpu_fusion, conv_horizontal_fusion)
size_t cpu_cb = count_ops_of_type<op::ConvolutionBias>(cpu_f); size_t cpu_cb = count_ops_of_type<op::ConvolutionBias>(cpu_f);
ASSERT_EQ(cpu_cb, 1); ASSERT_EQ(cpu_cb, 1);
} }
#endif
// ConvolutionBiasAdd relies on an in-place fused MKLDNN kernel. // ConvolutionBiasAdd relies on an in-place fused MKLDNN kernel.
// Need to ensure that it is fused only when in-place buffer allocation is feasible // Need to ensure that it is fused only when in-place buffer allocation is feasible
......
...@@ -195,8 +195,9 @@ TEST(cpu_test, mkldnn_layouts) ...@@ -195,8 +195,9 @@ TEST(cpu_test, mkldnn_layouts)
EXPECT_EQ(vector<float>{expected_result}, rv); EXPECT_EQ(vector<float>{expected_result}, rv);
} }
TEST(cpu_test, reshape_squeeze) TEST(cpu_test, reshape_layout_optimizations1)
{ {
// Squeeze outermost dimension
auto make_function = []() -> std::shared_ptr<Function> { auto make_function = []() -> std::shared_ptr<Function> {
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 16, 2, 2}); auto A = make_shared<op::Parameter>(element::f32, Shape{1, 16, 2, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{32, 16, 1, 1}); auto B = make_shared<op::Parameter>(element::f32, Shape{32, 16, 1, 1});
...@@ -233,8 +234,9 @@ TEST(cpu_test, reshape_squeeze) ...@@ -233,8 +234,9 @@ TEST(cpu_test, reshape_squeeze)
} }
} }
TEST(cpu_test, reshape_expand) TEST(cpu_test, reshape_layout_optimizations2)
{ {
// ExpandDims - inner most and internal dims
auto make_function = []() -> std::shared_ptr<Function> { auto make_function = []() -> std::shared_ptr<Function> {
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 16, 2, 2}); auto A = make_shared<op::Parameter>(element::f32, Shape{1, 16, 2, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{32, 16, 1, 1}); auto B = make_shared<op::Parameter>(element::f32, Shape{32, 16, 1, 1});
...@@ -271,8 +273,9 @@ TEST(cpu_test, reshape_expand) ...@@ -271,8 +273,9 @@ TEST(cpu_test, reshape_expand)
} }
} }
TEST(cpu_test, reshape_squeeze_padded) TEST(cpu_test, reshape_layout_optimizations3)
{ {
// Squeeze padded dimension
auto make_function = []() -> std::shared_ptr<Function> { auto make_function = []() -> std::shared_ptr<Function> {
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 16, 2, 2}); auto A = make_shared<op::Parameter>(element::f32, Shape{1, 16, 2, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{1, 16, 1, 1}); auto B = make_shared<op::Parameter>(element::f32, Shape{1, 16, 1, 1});
...@@ -310,8 +313,9 @@ TEST(cpu_test, reshape_squeeze_padded) ...@@ -310,8 +313,9 @@ TEST(cpu_test, reshape_squeeze_padded)
} }
} }
TEST(cpu_test, reshape_expand_squeeze) TEST(cpu_test, reshape_layout_optimizations4)
{ {
// Squeeze and expand dimensions. Ensure no extra conversions downstream
auto make_function = []() -> std::shared_ptr<Function> { auto make_function = []() -> std::shared_ptr<Function> {
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 16, 1, 8}); auto A = make_shared<op::Parameter>(element::f32, Shape{1, 16, 1, 8});
auto B1 = make_shared<op::Parameter>(element::f32, Shape{32, 16, 1, 1}); auto B1 = make_shared<op::Parameter>(element::f32, Shape{32, 16, 1, 1});
...@@ -322,7 +326,7 @@ TEST(cpu_test, reshape_expand_squeeze) ...@@ -322,7 +326,7 @@ TEST(cpu_test, reshape_expand_squeeze)
CoordinateDiff{0, 0}, CoordinateDiff{0, 0},
CoordinateDiff{0, 0}, CoordinateDiff{0, 0},
Strides{1, 1}); Strides{1, 1});
auto squeeze = make_shared<op::Reshape>(conv1, AxisVector{0, 1, 2, 3}, Shape{1, 32, 8}); auto squeeze = make_shared<op::Reshape>(conv1, AxisVector{0, 1, 2, 3}, Shape{32, 1, 8});
auto relu = make_shared<op::Relu>(squeeze); auto relu = make_shared<op::Relu>(squeeze);
auto expand = make_shared<op::Reshape>(relu, AxisVector{0, 1, 2}, Shape{1, 32, 1, 8}); auto expand = make_shared<op::Reshape>(relu, AxisVector{0, 1, 2}, Shape{1, 32, 1, 8});
auto B2 = make_shared<op::Parameter>(element::f32, Shape{8, 32, 1, 1}); auto B2 = make_shared<op::Parameter>(element::f32, Shape{8, 32, 1, 1});
...@@ -357,3 +361,120 @@ TEST(cpu_test, reshape_expand_squeeze) ...@@ -357,3 +361,120 @@ TEST(cpu_test, reshape_expand_squeeze)
} }
EXPECT_LE(count_ops_of_type<runtime::cpu::op::ConvertLayout>(cpu_f), 4); EXPECT_LE(count_ops_of_type<runtime::cpu::op::ConvertLayout>(cpu_f), 4);
} }
TEST(cpu_test, reshape_layout_optimizations5)
{
    // Conv output is expanded to 5D, run through Relu, then squeezed back to 4D
    // before a second Conv. CPU results must match INTERPRETER and the number of
    // inserted layout conversions must stay bounded.
    auto build_graph = []() -> std::shared_ptr<Function> {
        auto data = make_shared<op::Parameter>(element::f32, Shape{1, 16, 1, 8});
        auto filters1 = make_shared<op::Parameter>(element::f32, Shape{32, 16, 1, 1});
        auto conv1 = make_shared<op::Convolution>(data,
                                                  filters1,
                                                  Strides{1, 1},
                                                  Strides{1, 1},
                                                  CoordinateDiff{0, 0},
                                                  CoordinateDiff{0, 0},
                                                  Strides{1, 1});
        auto expand =
            make_shared<op::Reshape>(conv1, AxisVector{0, 1, 2, 3}, Shape{1, 1, 32, 1, 8});
        auto relu = make_shared<op::Relu>(expand);
        auto squeeze =
            make_shared<op::Reshape>(relu, AxisVector{0, 1, 2, 3, 4}, Shape{1, 32, 1, 8});
        auto filters2 = make_shared<op::Parameter>(element::f32, Shape{8, 32, 1, 1});
        auto conv2 = make_shared<op::Convolution>(squeeze,
                                                  filters2,
                                                  Strides{1, 1},
                                                  Strides{1, 1},
                                                  CoordinateDiff{0, 0},
                                                  CoordinateDiff{0, 0},
                                                  Strides{1, 1});
        return make_shared<Function>(NodeVector{conv2},
                                     op::ParameterVector{data, filters1, filters2});
    };
    auto backend = runtime::Backend::create("CPU");
    auto cpu_f = build_graph();
    auto int_f = build_graph();
    test::Uniform<float> rng(-100.0f, 100.0f);
    vector<vector<float>> args;
    for (const auto& param : cpu_f->get_parameters())
    {
        // Allocate the flat tensor in place, then fill it from the RNG.
        args.emplace_back(shape_size(param->get_shape()));
        rng.initialize(args.back());
    }
    auto int_results = execute(int_f, args, "INTERPRETER");
    auto cpu_results = execute(cpu_f, args, "CPU");
    for (size_t i = 0; i < cpu_results.size(); i++)
    {
        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
    }
    EXPECT_LE(count_ops_of_type<runtime::cpu::op::ConvertLayout>(cpu_f), 4);
}
TEST(cpu_test, reshape_layout_optimizations6)
{
    // Sum reduces away the outermost dimension and Reshape restores it.
    // Expect zero layout conversions to be inserted anywhere in the graph.
    auto build_graph = []() -> std::shared_ptr<Function> {
        auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4, 3, 2});
        auto squared = make_shared<op::Multiply>(data, data);
        auto reduced = make_shared<op::Sum>(squared, AxisVector{0});
        auto restored =
            make_shared<op::Reshape>(reduced, AxisVector{0, 1, 2}, Shape{1, 4, 3, 2});
        auto root = make_shared<op::Sqrt>(restored);
        return make_shared<Function>(NodeVector{root}, op::ParameterVector{data});
    };
    auto backend = runtime::Backend::create("CPU");
    auto cpu_f = build_graph();
    auto int_f = build_graph();
    test::Uniform<float> rng(-100.0f, 100.0f);
    vector<vector<float>> args;
    for (const auto& param : cpu_f->get_parameters())
    {
        // Allocate the flat tensor in place, then fill it from the RNG.
        args.emplace_back(shape_size(param->get_shape()));
        rng.initialize(args.back());
    }
    auto int_results = execute(int_f, args, "INTERPRETER");
    auto cpu_results = execute(cpu_f, args, "CPU");
    for (size_t i = 0; i < cpu_results.size(); i++)
    {
        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i)));
    }
    EXPECT_EQ(count_ops_of_type<runtime::cpu::op::ConvertLayout>(cpu_f), 0);
}
TEST(cpu_test, reshape_layout_optimizations7)
{
    // Sum collapses the two outermost dimensions; Reshape re-expands the
    // rank-3 result back to rank 5. Expect zero layout conversions downstream.
    auto build_graph = []() -> std::shared_ptr<Function> {
        auto data = make_shared<op::Parameter>(element::f32, Shape{1, 4, 10, 6, 10});
        auto squared = make_shared<op::Multiply>(data, data);
        auto reduced = make_shared<op::Sum>(squared, AxisVector{0, 1});
        auto root =
            make_shared<op::Reshape>(reduced, AxisVector{0, 1, 2}, Shape{1, 1, 10, 6, 10});
        return make_shared<Function>(NodeVector{root}, op::ParameterVector{data});
    };
    auto backend = runtime::Backend::create("CPU");
    auto cpu_f = build_graph();
    auto int_f = build_graph();
    test::Uniform<float> rng(-100.0f, 100.0f);
    vector<vector<float>> args;
    for (const auto& param : cpu_f->get_parameters())
    {
        // Allocate the flat tensor in place, then fill it from the RNG.
        args.emplace_back(shape_size(param->get_shape()));
        rng.initialize(args.back());
    }
    auto int_results = execute(int_f, args, "INTERPRETER");
    auto cpu_results = execute(cpu_f, args, "CPU");
    for (size_t i = 0; i < cpu_results.size(); i++)
    {
        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i)));
    }
    EXPECT_EQ(count_ops_of_type<runtime::cpu::op::ConvertLayout>(cpu_f), 0);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment