Skip to content

Commit

Permalink
Strengthen sample efficiency of locate bbx functions
Browse files Browse the repository at this point in the history
  • Loading branch information
billhhh authored Sep 22, 2024
1 parent b645c82 commit bcd8a14
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions BraTSDataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ def locate_bbx(self, label):
class_num, img_d, img_h, img_w = label.shape

if random.random() < 0.5:
selected_class = np.random.choice(class_num + 1)
selected_class = np.random.choice(class_num)
class_locs = []
if selected_class != class_num:
class_label = label[selected_class]
class_locs = np.argwhere(class_label > 0)

if selected_class == class_num or len(class_locs) == 0:
if len(class_locs) == 0:
# if no foreground found, then randomly select
d0 = random.randint(0, img_d - 0 - self.crop_d)
h0 = random.randint(15, img_h - 15 - self.crop_h)
Expand Down Expand Up @@ -149,8 +149,6 @@ def locate_bbx_wScale(self, label):
scale_flag = False
if self.scale and np.random.uniform() < 0.5:
scaler = np.random.uniform(0.9, 1.1)
# if self.scale and np.random.uniform() < 0.2:
# scaler = np.random.uniform(0.85, 1.25)
scale_flag = True
else:
scaler = 1
Expand All @@ -161,13 +159,13 @@ def locate_bbx_wScale(self, label):
class_num, img_d, img_h, img_w = label.shape

if random.random() < 0.5:
selected_class = np.random.choice(class_num + 1)
selected_class = np.random.choice(class_num)
class_locs = []
if selected_class != class_num:
class_label = label[selected_class]
class_locs = np.argwhere(class_label > 0)

if selected_class == class_num or len(class_locs) == 0:
if len(class_locs) == 0:
# if no foreground found, then randomly select
d0 = random.randint(0, img_d - 0 - scale_d)
h0 = random.randint(15, img_h - 15 - scale_h)
Expand Down

0 comments on commit bcd8a14

Please sign in to comment.