import os

import numpy as np
from skimage import measure


def post_process_mask(label, p=0.5):
    """Binarize a score map by thresholding.

    Args:
        label: Array-like of scores.
        p: Threshold; elements strictly greater than ``p`` become 1,
            all others become 0.

    Returns:
        ``uint8`` array with the same shape as ``label`` holding 0s and 1s.
    """
    return (np.asarray(label) > p).astype(np.uint8)


def get_file_path_list(path):
    """Return the paths of every file found recursively under ``path``.

    Each entry is ``os.path.join(directory, filename)`` for the directories
    yielded by ``os.walk``, in ``os.walk`` traversal order.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(path)
        for filename in filenames
    ]


def extract_positive_windows(x, shape=(256, 256)):
    """Return, per connected component of ``x``, the area any overlapping window can span.

    The components come from ``skimage.measure.label``. For each one, the
    component's bounding box is grown by ``shape - 1`` on every side. The result
    is the union of all ``shape``-sized windows that overlap the bounding box.
    That union is clipped to the mask's own extent.

    Args:
        x: 2-D mask ``(H, W)``, or channels-first 3-D mask ``(C, H, W)``.
            For a 3-D mask, the channels are merged (logical OR) before labelling.
        shape: ``(height, width)`` of the window.

    Returns:
        List of integer arrays ``[min_row, min_col, max_row, max_col]``.
        The max values are exclusive, so they can be used directly as slice ends.

    Raises:
        ValueError: if ``x`` is neither 2-D nor 3-D.
    """
    mask = np.asarray(x)
    if mask.ndim == 3:
        # extract_boxes passes channels-first masks (it slices Y[i, :, r, c]).
        # Collapse the channel axis so regionprops yields 2-D, 4-element bboxes.
        mask = mask.any(axis=0)
    elif mask.ndim != 2:
        raise ValueError(f"expected a 2-D or 3-D mask, got ndim={mask.ndim}")

    n_rows, n_cols = mask.shape
    pad_rows, pad_cols = shape[0] - 1, shape[1] - 1

    valid_boxes = []
    for prop in measure.regionprops(measure.label(mask)):
        min_r, min_c, max_r, max_c = prop.bbox
        valid_boxes.append(np.array([
            max(min_r - pad_rows, 0),
            max(min_c - pad_cols, 0),
            # Clip to the mask's real extent (exclusive end). A hard-coded image
            # size would drop the last row/column or be wrong for other sizes.
            min(max_r + pad_rows, n_rows),
            min(max_c + pad_cols, n_cols),
        ]))
    return valid_boxes


def extract_boxes(X, Y, shape=(256, 256)):
    """Cut matching image/mask regions around every positive component of each mask.

    Args:
        X: Iterable of channels-first images ``(C, H, W)``.
        Y: Indexable collection of channels-first masks. ``Y[i]`` pairs with
            the i-th element of ``X``.
        shape: Window size forwarded to ``extract_positive_windows``.

    Returns:
        ``(X_boxes, Y_boxes)``: two flat lists of array slices. Both are in the
        same order, one entry per window.
    """
    X_boxes = []
    Y_boxes = []
    for i, x in enumerate(X):
        y = Y[i]
        for min_r, min_c, max_r, max_c in extract_positive_windows(y, shape):
            X_boxes.append(x[:, min_r:max_r, min_c:max_c])
            Y_boxes.append(y[:, min_r:max_r, min_c:max_c])
    return X_boxes, Y_boxes


def rand_crop(X, Y, shape):
    """Take the same uniformly random ``shape`` crop from ``X`` and ``Y``.

    Args:
        X: Channels-first array ``(C, H, W)``.
        Y: Array whose axes 1 and 2 are spatially aligned with those of ``X``.
        shape: ``(height, width)`` of the crop.

    Returns:
        ``(X_crop, Y_crop)`` views of ``X`` and ``Y``.

    Raises:
        ValueError: if ``X`` is smaller than ``shape`` along either spatial axis.
    """
    n_rows, n_cols = X.shape[1], X.shape[2]
    if n_rows < shape[0] or n_cols < shape[1]:
        raise ValueError(
            f"crop {tuple(shape)} does not fit in spatial size {(n_rows, n_cols)}"
        )
    # randint's upper bound is exclusive. The +1 makes the last valid offset
    # (which puts the crop flush with the far edge) reachable. It also keeps
    # the bound positive when the image is exactly the crop size.
    start_row = np.random.randint(n_rows - shape[0] + 1)
    start_col = np.random.randint(n_cols - shape[1] + 1)
    rows = slice(start_row, start_row + shape[0])
    cols = slice(start_col, start_col + shape[1])
    return X[:, rows, cols], Y[:, rows, cols]


def generate_random_crops(X, Y, num_patches=10, crop_size=(256, 256)):
    """Sample random, randomly mirrored crops from each sufficiently large image.

    Images smaller than ``crop_size`` along either spatial axis are skipped.
    Each remaining image yields ``num_patches`` crops. Each crop is flipped
    along its last axis with probability 0.5 (``np.random.rand() > 0.5``).

    Args:
        X: Iterable of channels-first images ``(C, H, W)``.
        Y: Indexable collection of masks; ``Y[i]`` pairs with the i-th image.
        num_patches: Number of crops to draw per image.
        crop_size: ``(height, width)`` of each crop.

    Returns:
        ``(X_crops, Y_crops)``, each stacked along a new leading axis.

    Raises:
        ValueError: if no crop was produced (every image too small, or
            ``num_patches`` <= 0).
    """
    X_crops = []
    Y_crops = []
    for i, x in enumerate(X):
        if x.shape[1] < crop_size[0] or x.shape[2] < crop_size[1]:
            continue
        for _ in range(num_patches):
            x_crop, y_crop = rand_crop(x, Y[i], crop_size)
            if np.random.rand() > 0.5:
                # Mirror both arrays along the last (column) axis.
                x_crop = x_crop[:, :, ::-1]
                y_crop = y_crop[:, :, ::-1]
            X_crops.append(x_crop)
            Y_crops.append(y_crop)
    if not X_crops:
        raise ValueError(
            f"no crops produced: no image is at least {tuple(crop_size)} "
            f"or num_patches={num_patches}"
        )
    return np.stack(X_crops), np.stack(Y_crops)