Skip to content

Commit c6a650f

Browse files
authored
fix corner cases of static padding
1 parent 269b356 commit c6a650f

1 file changed

Lines changed: 4 additions & 14 deletions

File tree

efficientnet/utils_extra.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True,
3333
def forward(self, x):
3434
h, w = x.shape[-2:]
3535

36-
h_step = math.ceil(w / self.stride[1])
37-
v_step = math.ceil(h / self.stride[0])
38-
h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1)
39-
v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1)
40-
41-
extra_h = h_cover_len - w
42-
extra_v = v_cover_len - h
36+
extra_h = (w - 1) * self.stride[1] - w + self.kernel_size[1]
37+
extra_v = (h - 1) * self.stride[0] - h + self.kernel_size[0]
4338

4439
left = extra_h // 2
4540
right = extra_h - left
@@ -77,13 +72,8 @@ def __init__(self, *args, **kwargs):
7772
def forward(self, x):
7873
h, w = x.shape[-2:]
7974

80-
h_step = math.ceil(w / self.stride[1])
81-
v_step = math.ceil(h / self.stride[0])
82-
h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1)
83-
v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1)
84-
85-
extra_h = h_cover_len - w
86-
extra_v = v_cover_len - h
75+
extra_h = (w - 1) * self.stride[1] - w + self.kernel_size[1]
76+
extra_v = (h - 1) * self.stride[0] - h + self.kernel_size[0]
8777

8878
left = extra_h // 2
8979
right = extra_h - left

0 commit comments

Comments
 (0)