@@ -93,9 +93,9 @@ class ConvOperator : public Operator {
9393 bool caffeMode_;
9494 int inputOffset_, outputOffset_, weightOffset_;
9595 int numFilters_;
96- int padding_, stride_, filterSize_, channels_, imgSize_;
96+ int padding_, stride_, filterSize_, channels_, imgSize_, imgSizeY_ ;
9797 int paddingY_, strideY_, filterSizeY_;
98- int imgPixels_, filterPixels_, filterChannels_, outputX_, outputs_;
98+ int imgPixels_, filterPixels_, filterChannels_, outputX_, outputY_, outputs_;
9999
100100 // / Following member variables are same with CudnnConvLayer.
101101 // / There is no explanation here.
@@ -144,7 +144,7 @@ void ConvOperator::allocConvWorkSpace(size_t maxWorkSpace) {
144144void ConvOperator::reshape (int batchSize) {
145145 imageH_ = ins_[0 ]->getFrameHeight ();
146146 imageW_ = ins_[0 ]->getFrameWidth ();
147- if (imageH_ == 0 ) imageH_ = imgSize_ ;
147+ if (imageH_ == 0 ) imageH_ = imgSizeY_ ;
148148 if (imageW_ == 0 ) imageW_ = imgSize_;
149149 outputH_ = outputSize (imageH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
150150 outputW_ = outputSize (imageW_, filterSize_, padding_, stride_, caffeMode_);
@@ -182,7 +182,10 @@ void ConvOperator::computeConvSizes() {
182182 hl_create_tensor_descriptor (&inputDesc_);
183183 int outputX =
184184 outputSize (imgSize_, filterSize_, padding_, stride_, caffeMode_);
185+ int outputY =
186+ outputSize (imgSizeY_, filterSizeY_, paddingY_, strideY_, caffeMode_);
185187 CHECK_EQ (outputX, outputX_);
188+ CHECK_EQ (outputY, outputY_);
186189 hl_create_tensor_descriptor (&outputDesc_);
187190 hl_create_convolution_descriptor (&convDesc_,
188191 inputDesc_,
@@ -236,10 +239,12 @@ void ConvOperator::getConvParams() {
236239 filterPixels_ = filterSize_ * filterSizeY_;
237240 channels_ = conf.channels ();
238241 imgSize_ = conf.img_size ();
239- imgPixels_ = imgSize_ * imgSize_;
242+ imgSizeY_ = conf.has_img_size_y () ? conf.img_size_y () : conf.img_size ();
243+ imgPixels_ = imgSize_ * imgSizeY_;
240244 CHECK_EQ (conf.groups (), 1U );
241245 filterChannels_ = conf.filter_channels ();
242246 outputX_ = conf.output_x ();
247+ outputY_ = conf.has_output_y () ? conf.output_y () : conf.output_x ();
243248 outputs_ = outputX_ * outputX_;
244249}
245250
0 commit comments