Commit 8508410f authored by Chris Sullivan, committed by Scott Cyphers

Add op::Or and op::And to GPU transformer (#979)

* Moved emit_elementwise implementation into CUDAEmitter and added logical_and and logical_or ops.

* Updated comment and formatting.

* Added check for multi-output elementwise ops.
parent 7bc6b785
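For context, a minimal, illustrative sketch (not part of this commit) of the kind of Boolean graph these ops enable on the GPU backend; it follows the construction style of the backend tests enabled at the bottom of this diff, and the shape is an arbitrary example.

    // illustrative only: element-wise Boolean ops over boolean (char) tensors
    Shape shape{2, 2, 2};
    auto A = make_shared<op::Parameter>(element::boolean, shape);
    auto B = make_shared<op::Parameter>(element::boolean, shape);
    auto logical_and = make_shared<op::And>(A, B); // now lowered through the CUDA element-wise emitter
    auto logical_or = make_shared<op::Or>(A, B);   // likewise

The logical_and and logical_or unit tests, previously skipped on GPU, exercise exactly this path.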
......@@ -21,6 +21,7 @@
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/runtime/gpu/cuda_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
......@@ -603,6 +604,79 @@ size_t runtime::gpu::CUDAEmitter::build_avg_pool(const GPURuntimeContext* ctx,
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_elementwise_n_to_1(const GPURuntimeContext* ctx,
const std::vector<std::string>& dtypes,
const Shape& tensor_shape,
const char* op,
const char* kernel)
{
// kernel_name is used to check if the cuda kernel has been previously compiled
std::stringstream kernel_name;
kernel_name << "ew"
<< "_" << op << "_" << join(dtypes, "_");
// hash is used to check if the emitted primitive already exists
std::stringstream ss;
ss << kernel_name.str() << "_s" << join(tensor_shape, "_");
auto hash = ss.str();
// if the primitive exists, we are done
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, compile the kernel
// first and then create the primitive
auto compiled_kernel = ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
if (kernel)
{
CudaKernelBuilder::get_device_helper(writer, op, kernel, dtypes);
}
CudaKernelBuilder::get_elementwise_op(writer, kernel_name.str(), op, dtypes);
std::string kernel_code = writer.get_code();
compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name.str(), kernel_code);
}
size_t nthreads = shape_size(tensor_shape);
// create the launch primitive
std::unique_ptr<gpu::primitive> ew(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
std::vector<void*> args_list;
for (auto i = 0u; i < dtypes.size() - 1; i++)
{
args_list.push_back(&inputs[i]);
}
args_list.push_back(&outputs[0]);
args_list.push_back(&nthreads);
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(nthreads),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // wait for the kernel to complete
}});
primitive_index = this->m_primitive_emitter->insert(std::move(ew));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
void runtime::gpu::CUDAEmitter::print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name,
const Shape& shape)
......
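To make the new caching scheme concrete, a hedged sketch of how the emitter is driven for op::And on boolean tensors; cuda_emitter, ctx, inputs, and outputs are assumed in-scope names, and the char dtype strings assume boolean tensors lower to the char type string.

    std::vector<std::string> dtypes{"char", "char", "char"}; // two input types, then the output type
    size_t and_index = cuda_emitter->build_elementwise<ngraph::op::And>(ctx, dtypes, Shape{2, 2, 2});
    // kernel_name: "ew_logical_and_char_char_char"        -> compiled at most once per op/dtype signature
    // cache hash:  "ew_logical_and_char_char_char_s2_2_2" -> one launch primitive per tensor shape
    gpu::invoke_primitive(ctx, and_index, inputs, outputs); // launch with raw device pointers

A second build_elementwise call with the same dtypes and shape hits the m_primitive_emitter cache and returns the same index without touching the compiler.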
......@@ -18,6 +18,7 @@
#include <array>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
namespace ngraph
{
......@@ -60,12 +61,26 @@ namespace ngraph
const Shape& padding_below,
bool include_pad = false);
template <typename T>
size_t build_elementwise(const GPURuntimeContext* ctx,
const std::vector<std::string>& dtypes,
const Shape& tensor_shape)
{
return build_elementwise_n_to_1(
ctx, dtypes, tensor_shape, CudaOpMap<T>::op, CudaOpMap<T>::math_kernel);
}
private:
CUDAEmitter(GPUPrimitiveEmitter* emitter);
void print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name,
const Shape& shape);
std::string include_helpers();
size_t build_elementwise_n_to_1(const GPURuntimeContext* ctx,
const std::vector<std::string>& dtypes,
const Shape& tensor_shape,
const char* op,
const char* kernel);
GPUPrimitiveEmitter* m_primitive_emitter;
};
......
......@@ -23,9 +23,9 @@ using namespace ngraph;
void runtime::gpu::CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& op,
const std::vector<std::string>& data_types,
const size_t& num_inputs)
const std::vector<std::string>& data_types)
{
auto num_inputs = data_types.size() - 1;
writer << "extern \"C\" __global__ void cuda_" << name << "(";
for (size_t i = 0; i < num_inputs; i++)
{
......@@ -245,12 +245,12 @@ void runtime::gpu::CudaKernelBuilder::get_reverse_op(codegen::CodeWriter& writer
void runtime::gpu::CudaKernelBuilder::get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& math_kernel,
const std::vector<std::string>& data_types,
const size_t& num_inputs)
const std::vector<std::string>& data_types)
{
if (math_kernel.size())
{
writer << "__device__ " << data_types[num_inputs] << " " << name << "(";
auto num_inputs = data_types.size() - 1;
writer << "__device__ __forceinline__ " << data_types[num_inputs] << " " << name << "(";
for (size_t i = 0; i < num_inputs - 1; i++)
{
writer << data_types[i] << " x" << i << ", ";
......
......@@ -36,8 +36,7 @@ namespace ngraph
static void get_elementwise_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& op,
const std::vector<std::string>& data_types,
const size_t& num_inputs);
const std::vector<std::string>& data_types);
static void get_broadcast_op(codegen::CodeWriter& writer,
const std::string& name,
......@@ -67,8 +66,7 @@ namespace ngraph
static void get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& math_kernel,
const std::vector<std::string>& data_types,
const size_t& num_inputs);
const std::vector<std::string>& data_types);
static void add_pod_typedefs(codegen::CodeWriter& writer);
};
......
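For reference, a hedged sketch of roughly what the two builders above produce for a two-input logical_and over char tensors; parameter names, the thread-index computation, and formatting are approximations (the full bodies of get_elementwise_op and get_device_helper are not shown in this diff), but the helper signature and kernel name follow the code above.

    // from get_device_helper, with math_kernel = "x0 & x1"
    __device__ __forceinline__ char logical_and(char x0, char x1)
    {
        return x0 & x1;
    }
    // from get_elementwise_op: one thread per output element, n = shape_size(tensor_shape)
    extern "C" __global__ void cuda_ew_logical_and_char_char_char(char* in0, char* in1, char* out0, size_t n)
    {
        size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
        if (tid < n)
        {
            out0[tid] = logical_and(in0[tid], in1[tid]);
        }
    }

build_elementwise_n_to_1 launches this with a grid of n single-thread blocks, so each block handles one output element.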
......@@ -86,55 +86,6 @@ namespace ngraph
size_t rank,
size_t count);
template <typename T, typename... Inputs>
void emit_elementwise_op(const std::string& name,
const std::vector<std::string>& data_types,
GPURuntimeContext* ctx,
size_t count,
CUdeviceptr out,
Inputs&&... inputs)
{
std::string type_signature = "_" + join(data_types, "_");
std::replace(type_signature.begin(), type_signature.end(), ' ', '_');
auto compiled_kernel = ctx->compiled_kernel_pool->get(name + type_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
std::string op_name = CudaOpMap<T>::op;
if (CudaOpMap<T>::math_kernel)
{
op_name += type_signature;
CudaKernelBuilder::get_device_helper(writer,
op_name,
CudaOpMap<T>::math_kernel,
data_types,
sizeof...(inputs));
}
CudaKernelBuilder::get_elementwise_op(
writer, name + type_signature, op_name, data_types, sizeof...(inputs));
std::string kernel = writer.get_code();
compiled_kernel = ctx->compiled_kernel_pool->set(name + type_signature, kernel);
}
void* args_list[] = {&inputs..., &out, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
count,
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
template <typename... Inputs>
void emit_concat_op(const std::string& name,
const std::vector<std::string>& data_types,
......
......@@ -54,11 +54,16 @@ namespace ngraph
class Not;
class Sqrt;
class Select;
class And;
class Or;
}
namespace runtime
{
namespace gpu
{
template <typename T>
struct CudaOpMap;
template <>
struct CudaOpMap<ngraph::op::Abs>
{
......@@ -265,7 +270,7 @@ namespace ngraph
template <>
struct CudaOpMap<ngraph::op::Not>
{
static constexpr const char* op = "not";
static constexpr const char* op = "logical_not";
static constexpr const char* math_kernel = "!x0";
};
......@@ -282,6 +287,20 @@ namespace ngraph
static constexpr const char* op = "relu_backprop";
static constexpr const char* math_kernel = "x1 * int(x0 > 0)";
};
template <>
struct CudaOpMap<ngraph::op::And>
{
static constexpr const char* op = "logical_and";
static constexpr const char* math_kernel = "x0 & x1";
};
template <>
struct CudaOpMap<ngraph::op::Or>
{
static constexpr const char* op = "logical_or";
static constexpr const char* math_kernel = "x0 | x1";
};
}
}
}
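A brief sketch of how these traits are consumed: build_elementwise<T> (declared earlier in this diff) forwards the two strings to build_elementwise_n_to_1, so for T = ngraph::op::Or the call effectively expands to the line below; because nGraph boolean tensors hold 0/1 char values, the bitwise & and | kernels coincide with logical and/or.

    // illustrative expansion for T = ngraph::op::Or
    return build_elementwise_n_to_1(ctx, dtypes, tensor_shape,
                                    CudaOpMap<ngraph::op::Or>::op,           // "logical_or"
                                    CudaOpMap<ngraph::op::Or>::math_kernel); // "x0 | x1"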
......@@ -106,40 +106,6 @@ namespace ngraph
{
namespace gpu
{
void GPU_Emitter::emit_elementwise(
GPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const vector<runtime::gpu::GPU_TensorViewWrapper>& args,
const vector<runtime::gpu::GPU_TensorViewWrapper>& out)
{
if (out[0].get_size() == 0)
{
return;
}
writer.block_begin(" // " + node->get_name());
writer << "int count = " << out[0].get_size() << ";\n";
writer << "if(count == 0) return;\n";
writer << "ngraph::runtime::gpu::emit_elementwise_op<ngraph::op::"
<< node->description() << ">(\"" << node->description() << "\""
<< ", std::vector<std::string>{";
for (size_t i = 0; i < args.size(); i++)
{
writer << "\"" << args[i].get_type() << "\", ";
}
writer << "\"" << out[0].get_type() << "\"}"
<< ", ctx"
<< ", count"
<< ", CUdeviceptr(" << out[0].get_name() << ")";
for (size_t i = 0; i < args.size(); i++)
{
writer << ", CUdeviceptr(" << args[i].get_name() << ")";
}
writer << ");\n";
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Add)
{
......
......@@ -58,11 +58,47 @@ namespace ngraph
{
}
template <typename T>
static void emit_elementwise(GPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<GPU_TensorViewWrapper>& args,
const std::vector<GPU_TensorViewWrapper>& out);
const std::vector<GPU_TensorViewWrapper>& out)
{
if (out[0].get_size() == 0)
{
return;
}
else if (out.size() > 1)
{
throw std::runtime_error(
"Multi-output elementwise ops are not currently supported.");
}
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
writer.block_begin(" // " + node->get_name());
{
std::vector<std::string> dtypes;
for (auto& arg : args)
{
dtypes.push_back(arg.get_type());
}
dtypes.push_back(out[0].get_type());
auto ew_index = cuda_emitter->build_elementwise<T>(
external_function->ctx().get(), dtypes, out[0].get_shape());
writer << "gpu::invoke_primitive(ctx, " << ew_index << ", ";
writer << "std::vector<void*>{" << args.front().get_name();
for (size_t i = 1; i < args.size(); i++)
{
writer << ", " << args[i].get_name();
}
writer << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
};
Shape get_padded_shape(const Shape& input_shape,
const Shape& padding_below,
......
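To illustrate the new emission path end to end, a hedged sketch of the C++ that emit_elementwise now writes into the generated function for a two-input op such as op::And; the node name, primitive index, and tensor variable names are placeholders, and the exact brace/comment layout produced by CodeWriter::block_begin may differ.

    {   // And_3  (placeholder node name)
        gpu::invoke_primitive(ctx, 42,                       // 42: placeholder index returned by build_elementwise
                              std::vector<void*>{arg0, arg1}.data(),
                              std::vector<void*>{out0}.data());
    }

Compared with the removed emit_elementwise_op path, kernel compilation and launch setup now happen once at emission time inside CUDAEmitter, and the generated code only invokes a cached primitive.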
......@@ -42,6 +42,7 @@
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -77,6 +78,7 @@
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/power.hpp"
......@@ -170,54 +172,55 @@ static const runtime::gpu::OpMap dispatcher{
{TI(ngraph::op::Dot), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Dot>},
{TI(ngraph::op::Multiply), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Multiply>},
{TI(ngraph::op::Parameter), &runtime::gpu::GPU_Emitter::nop},
{TI(ngraph::op::Abs), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Abs), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Abs>},
{TI(ngraph::op::Concat), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Concat>},
{TI(ngraph::op::Divide), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Equal), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Divide), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Divide>},
{TI(ngraph::op::Equal), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Equal>},
{TI(ngraph::op::GetOutputElement),
&runtime::gpu::GPU_Emitter::emit<ngraph::op::GetOutputElement>},
{TI(ngraph::op::Greater), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::GreaterEq), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Less), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::LessEq), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Log), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Greater), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Greater>},
{TI(ngraph::op::GreaterEq),
&runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::GreaterEq>},
{TI(ngraph::op::Less), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Less>},
{TI(ngraph::op::LessEq), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::LessEq>},
{TI(ngraph::op::Log), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Log>},
{TI(ngraph::op::Maximum), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Maximum>},
{TI(ngraph::op::Minimum), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Minimum>},
{TI(ngraph::op::Negative), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Negative>},
{TI(ngraph::op::NotEqual), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Power), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Select), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Subtract), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::NotEqual), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::NotEqual>},
{TI(ngraph::op::Power), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Power>},
{TI(ngraph::op::Select), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Select>},
{TI(ngraph::op::Subtract), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Subtract>},
{TI(ngraph::op::Broadcast), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Broadcast>},
{TI(ngraph::op::Convert), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Convert), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Convert>},
{TI(ngraph::op::Constant), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Constant>},
{TI(ngraph::op::Reshape), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Reshape>},
{TI(ngraph::op::FunctionCall), &runtime::gpu::GPU_Emitter::emit<ngraph::op::FunctionCall>},
{TI(ngraph::op::Reduce), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Reduce>},
{TI(ngraph::op::Sign), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Sign), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Sign>},
{TI(ngraph::op::Slice), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Slice>},
{TI(ngraph::op::Sum), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Sum>},
{TI(ngraph::op::Exp), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Sin), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Sinh), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Cos), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Cosh), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Tan), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Tanh), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Asin), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Acos), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Atan), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Exp), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Exp>},
{TI(ngraph::op::Sin), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Sin>},
{TI(ngraph::op::Sinh), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Sinh>},
{TI(ngraph::op::Cos), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Cos>},
{TI(ngraph::op::Cosh), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Cosh>},
{TI(ngraph::op::Tan), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Tan>},
{TI(ngraph::op::Tanh), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Tanh>},
{TI(ngraph::op::Asin), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Asin>},
{TI(ngraph::op::Acos), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Acos>},
{TI(ngraph::op::Atan), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Atan>},
{TI(ngraph::op::ReplaceSlice), &runtime::gpu::GPU_Emitter::emit<ngraph::op::ReplaceSlice>},
{TI(ngraph::op::OneHot), &runtime::gpu::GPU_Emitter::emit<ngraph::op::OneHot>},
{TI(ngraph::op::Floor), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Ceiling), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Floor), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Floor>},
{TI(ngraph::op::Ceiling), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Ceiling>},
{TI(ngraph::op::Sqrt), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Sqrt>},
{TI(ngraph::op::Convolution), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::gpu::GPU_Emitter::emit<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::ConvolutionBackpropData),
&runtime::gpu::GPU_Emitter::emit<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::Not), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Not), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Not>},
{TI(ngraph::op::MaxPool), &runtime::gpu::GPU_Emitter::emit<ngraph::op::MaxPool>},
{TI(ngraph::op::Reverse), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Reverse>},
{TI(ngraph::op::Result), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Result>},
......@@ -236,10 +239,12 @@ static const runtime::gpu::OpMap dispatcher{
{TI(ngraph::op::Product), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Product>},
{TI(ngraph::op::Max), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Max>},
{TI(ngraph::op::Min), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Min>},
{TI(ngraph::op::Relu), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::ReluBackprop), &runtime::gpu::GPU_Emitter::emit_elementwise},
{TI(ngraph::op::Relu), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Relu>},
{TI(ngraph::op::ReluBackprop),
&runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::ReluBackprop>},
{TI(ngraph::op::Softmax), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Softmax>},
};
{TI(ngraph::op::And), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::And>},
{TI(ngraph::op::Or), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Or>}};
runtime::gpu::GPU_ExternalFunction::GPU_ExternalFunction(
const shared_ptr<ngraph::Function>& function, bool release_function)
......
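For orientation, a hedged sketch of how a type-indexed dispatcher like this is typically consulted while emitting each node; the actual lookup lives in GPU_ExternalFunction outside the hunks shown here, so the variable names and error handling below are assumptions rather than the repository's code.

    // illustrative only: find the emitter registered for this node's concrete op type
    auto handler = dispatcher.find(type_index(typeid(*node)));
    if (handler == dispatcher.end())
    {
        throw ngraph_error("Unhandled op in GPU emitter: " + node->description());
    }
    handler->second(external_function, writer, node.get(), in, out); // e.g. emit_elementwise<op::And>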
......@@ -8053,7 +8053,6 @@ TEST(${BACKEND_NAME}, validate_call_output_shape)
TEST(${BACKEND_NAME}, logical_and)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = make_shared<op::Parameter>(element::boolean, shape);
......@@ -8074,7 +8073,6 @@ TEST(${BACKEND_NAME}, logical_and)
TEST(${BACKEND_NAME}, logical_or)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = make_shared<op::Parameter>(element::boolean, shape);
......