Unverified Commit 6ed9e990 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into ayzhuang/cf_gather

parents 37849092 f6a404eb
...@@ -60,6 +60,8 @@ ExternalProject_Add( ...@@ -60,6 +60,8 @@ ExternalProject_Add(
${GTEST_CMAKE_ARGS} ${GTEST_CMAKE_ARGS}
BINARY_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/build" BINARY_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/build"
EXCLUDE_FROM_ALL TRUE EXCLUDE_FROM_ALL TRUE
BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/ngraph/gtest/build/googlemock/gtest/libgtest.a
BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/ngraph/gtest/build/googlemock/libgmock.a
) )
#------------------------------------------------------------------------------ #------------------------------------------------------------------------------
......
...@@ -66,12 +66,18 @@ shared_ptr<Node> op::Gelu::copy_with_new_args(const NodeVector& new_args) const ...@@ -66,12 +66,18 @@ shared_ptr<Node> op::Gelu::copy_with_new_args(const NodeVector& new_args) const
void op::Gelu::pre_validate_and_infer_types() void op::Gelu::pre_validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
PartialShape input_pshape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(), input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ", "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type, input_element_type,
")."); ").");
if (input_pshape.is_dynamic())
{
set_output_type(0, input_element_type, input_pshape);
}
} }
void op::Gelu::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::Gelu::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
...@@ -94,12 +100,18 @@ op::GeluBackpropFactor::GeluBackpropFactor(const Output<Node>& x) ...@@ -94,12 +100,18 @@ op::GeluBackpropFactor::GeluBackpropFactor(const Output<Node>& x)
void op::GeluBackpropFactor::pre_validate_and_infer_types() void op::GeluBackpropFactor::pre_validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
PartialShape input_pshape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(), input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ", "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type, input_element_type,
")."); ").");
if (input_pshape.is_dynamic())
{
set_output_type(0, input_element_type, input_pshape);
}
} }
shared_ptr<Node> op::GeluBackpropFactor::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::GeluBackpropFactor::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -100,6 +100,10 @@ void op::GroupConvolution::pre_validate_and_infer_types() ...@@ -100,6 +100,10 @@ void op::GroupConvolution::pre_validate_and_infer_types()
get_groups()) == data_shape.to_shape()[1], get_groups()) == data_shape.to_shape()[1],
"Incorrect number of channels per filter"); "Incorrect number of channels per filter");
} }
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
} }
void op::GroupConvolution::post_validate_and_infer_types() void op::GroupConvolution::post_validate_and_infer_types()
......
...@@ -170,7 +170,7 @@ shared_ptr<Node> op::LayerNorm::copy_with_new_args(const NodeVector& new_args) c ...@@ -170,7 +170,7 @@ shared_ptr<Node> op::LayerNorm::copy_with_new_args(const NodeVector& new_args) c
} }
} }
void op::LayerNorm::pre_validate_and_infer_types() void op::LayerNorm::validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
...@@ -509,7 +509,7 @@ shared_ptr<Node> op::LayerNormBackprop::copy_with_new_args(const NodeVector& new ...@@ -509,7 +509,7 @@ shared_ptr<Node> op::LayerNormBackprop::copy_with_new_args(const NodeVector& new
} }
} }
void op::LayerNormBackprop::pre_validate_and_infer_types() void op::LayerNormBackprop::validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
......
...@@ -55,7 +55,7 @@ namespace ngraph ...@@ -55,7 +55,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -121,7 +121,7 @@ namespace ngraph ...@@ -121,7 +121,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment