Unverified Commit f7a34a02 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Properly setting OC for Group Convolution (#1161)

* group conv fix

* group conv fix

* fix typo
parent bb06c80b
......@@ -60,11 +60,12 @@ Shape op::GroupConvolution::get_weights_dimensions() const
{
//reshape weights into 5d tensors that includes groups
const size_t OC = 0;
const size_t OC_IN_OUTPUT = 1;
const size_t IC = 1;
Shape weights_shape_groups{get_inputs().at(1).get_shape()};
//adjust output and channel given a number of groups
weights_shape_groups.at(OC) /= get_groups();
weights_shape_groups.at(OC) = get_shape().at(OC_IN_OUTPUT) / get_groups();
weights_shape_groups.at(IC) = get_inputs().at(0).get_shape().at(IC) / get_groups();
//push_front the number of groups
weights_shape_groups.insert(weights_shape_groups.begin(), get_groups());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment