# Salt Body Interpretation on Seismic Data Using SageMaker and MXNet

This notebook is a tutorial on building a deep learning (semantic segmentation) model for automatic salt interpretation. It covers:
* The fully convolutional architecture known as UNet for semantic segmentation
* How to train UNet in Amazon SageMaker, and deploy to an inference endpoint

The modules this notebook needs are imported below, after a note on library versions.

## IMPORTANT

Depending on when you started your SageMaker notebook instance, you might need to reinstall the "scikit-image", "scikit-learn", "numpy" and "scipy" libraries, because some functions used in this notebook are not available in other versions (for example, <tt>scipy.misc.imresize</tt> and <tt>sklearn.cross_validation</tt> were removed in later releases). If you get errors involving any of these four libraries, uncomment the four lines of code below and run them.

In [None]:
# !pip install scikit-learn==0.16.0
# !pip install scikit-image==0.12.2
# !pip install scipy==1.2.1   
# !pip install numpy==1.16.4  
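
Before reinstalling anything, it may be worth checking whether the version-sensitive names this notebook imports are available in the current environment. The following is a minimal sketch, not part of the original notebook; <tt>missing_names</tt> and the <tt>required</tt> list are hypothetical names introduced for illustration.

```python
import importlib


def missing_names(requirements):
    """Return 'module.attribute' strings that cannot be imported.

    `requirements` is a list of (module_name, attribute_name) pairs.
    Hypothetical helper for illustration only.
    """
    missing = []
    for module_name, attribute in requirements:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # The module itself is absent (or was removed in this version).
            missing.append('{}.{}'.format(module_name, attribute))
            continue
        if not hasattr(module, attribute):
            # The module exists but no longer provides this name.
            missing.append('{}.{}'.format(module_name, attribute))
    return missing


# Names imported by this notebook that depend on the library version.
required = [
    ('scipy.misc', 'imresize'),
    ('sklearn.cross_validation', 'train_test_split'),
]
print(missing_names(required))
```

An empty list suggests the pinned versions above are not needed; any entry in the list is a reason to run the reinstall cell.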

In [None]:
import mxnet as mx
import segmentation_methods as sgm
from mxnet import ndarray as F
import numpy as np
import urllib
from PIL import Image
np.random.seed(1984)
import glob
import os
import zipfile
from scipy.misc import imresize
from sklearn.cross_validation import train_test_split
import scipy.io as sio
from skimage import measure
import time
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline

# Data ingestion

The dataset used in this study is provided by TGS and is available from the Kaggle TGS Salt Identification Challenge: https://www.kaggle.com/c/tgs-salt-identification-challenge
We first load the paths of the image files and of the mask files that contain the labels.

In [None]:
image_dir = 'data/train/images/'
image_files = sgm.get_file_path_list(image_dir)
label_dir = 'data/train/masks/'
label_files = sgm.get_file_path_list(label_dir)

Here we load the masks and convert them into a binary format:
* Because the images may not all share the same resolution, they are resized to a constant 820x550 (height x width). Any interpolated label values greater than zero are set to one. Note that PIL's <tt>resize</tt> takes <tt>(width, height)</tt>, which is why the code passes <tt>(550, 820)</tt>.
* MXNet requires the input to have dimension <tt>(batch, channel, height, width)</tt>, so the arrays are transposed and expanded to match.
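
As a toy illustration of the two steps above (a sketch with made-up 4x3 arrays, not the notebook's data): thresholding turns interpolated mask values into 0/1, and a transpose moves the channel axis so the arrays match the <tt>(batch, channel, height, width)</tt> layout. The names <tt>X_demo</tt> and <tt>Y_demo</tt> are hypothetical.

```python
import numpy as np

# Made-up stand-ins: two RGB images stored as (height, width, channel)
# and two grayscale masks stored as (height, width).
images = [np.zeros((4, 3, 3), dtype=np.uint8) for _ in range(2)]
masks = [
    np.array([[0, 0, 0], [0, 37, 255], [0, 128, 255], [0, 0, 0]]),
    np.zeros((4, 3), dtype=np.int64),
]

# Any value greater than zero (including interpolated grays) becomes 1.
binary_masks = [(mask > 0).astype(np.uint8) for mask in masks]

# (batch, height, width, channel) -> (batch, channel, height, width)
X_demo = np.transpose(np.stack(images), axes=(0, 3, 1, 2))
# (batch, height, width) -> (batch, 1, height, width)
Y_demo = np.expand_dims(np.stack(binary_masks), 1)

print(X_demo.shape, Y_demo.shape)
print(np.unique(Y_demo))
```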

In [None]:
image1 = np.array(Image.open(image_files[100]).resize((550,820)))
mask1 = np.array(Image.open(label_files[100]).resize((550,820)))

In [None]:
plt.figure(figsize=(12,12))
plt.subplot(131)
plt.title('Image')
plt.imshow(image1)

plt.subplot(132)
plt.imshow(mask1)
plt.title('Salt Mask')

## Stack images

In [None]:
X = []
Y = []
for i in range(len(image_files)):
    mask = (Image.open(label_files[i]))
    mask = (imresize(mask, (820, 550)) > 0).astype(np.uint8) # interpolate to 820 x 550
    image = np.array(Image.open(image_files[i]).resize((550,820)))
    X.append(image)
    Y.append(mask)
X = np.transpose(np.stack(X), axes=(0, 3, 1, 2))
Y = np.expand_dims(np.stack(Y), 1)

We have 4000 observations.

In [None]:
X.shape

# Data splitting and augmentation (random cropping)

We're going to generate a new dataset through random cropping of our existing images. Before we do that, we need to split the data into training and validation sets. If we cropped first and then split, crops taken from the same image could land in both sets, so validation information would leak into training.

In [None]:
train_X, validation_X, train_Y, validation_Y = train_test_split(X, Y, test_size=0.2, random_state=1984)

In [None]:
train_X_boxes, train_Y_boxes = sgm.extract_boxes(train_X, train_Y)
validation_X_boxes, validation_Y_boxes = sgm.extract_boxes(validation_X, validation_Y)

Next, we generate the random crops.

In [None]:
train_X_crops, train_Y_crops = sgm.generate_random_crops(train_X_boxes, train_Y_boxes, num_patches=3)
validation_X_crops, validation_Y_crops = sgm.generate_random_crops(validation_X_boxes, validation_Y_boxes, num_patches=3)
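
<tt>sgm.generate_random_crops</tt> is defined in the accompanying module and is not shown here. As a rough idea of what random cropping of aligned image/mask pairs involves, here is a minimal sketch, assuming <tt>(batch, channel, height, width)</tt> arrays and a 256x256 crop size (taken from the <tt>data_shape</tt> hyperparameter used later). The function <tt>random_crops</tt> and its parameters are hypothetical and may differ from the helper's actual behavior.

```python
import numpy as np


def random_crops(images, masks, crop_size=256, num_patches=3, seed=1984):
    """Hypothetical helper: cut `num_patches` aligned random crops per image.

    images: (N, C, H, W) array; masks: (N, 1, H, W) array.
    Assumes H and W are both at least `crop_size`.
    """
    rng = np.random.RandomState(seed)
    n, _, height, width = images.shape
    image_crops, mask_crops = [], []
    for i in range(n):
        for _ in range(num_patches):
            # Pick the top-left corner so the crop stays inside the image.
            top = rng.randint(0, height - crop_size + 1)
            left = rng.randint(0, width - crop_size + 1)
            rows = slice(top, top + crop_size)
            cols = slice(left, left + crop_size)
            # Use the same window for image and mask so they stay aligned.
            image_crops.append(images[i][:, rows, cols])
            mask_crops.append(masks[i][:, rows, cols])
    return np.stack(image_crops), np.stack(mask_crops)


# Usage with made-up arrays of the notebook's image size.
demo_X = np.zeros((2, 3, 820, 550), dtype=np.uint8)
demo_Y = np.zeros((2, 1, 820, 550), dtype=np.uint8)
demo_X_crops, demo_Y_crops = random_crops(demo_X, demo_Y)
print(demo_X_crops.shape, demo_Y_crops.shape)
```

Because the crops are drawn separately from the training and validation arrays, no crop from a validation image can end up in the training set.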

In [None]:
train_X_crops.shape

Next, we'll save the generated data locally so that we do not have to generate it again.

In [None]:
# Create the output directory (and its parent) if it does not exist yet.
os.makedirs('/dev/shm/salt/segmentation_data', exist_ok=True)
np.save('/dev/shm/salt/segmentation_data/train_X_crops.npy', train_X_crops)
np.save('/dev/shm/salt/segmentation_data/train_Y_crops.npy', train_Y_crops)
np.save('/dev/shm/salt/segmentation_data/validation_X_crops.npy', validation_X_crops)
np.save('/dev/shm/salt/segmentation_data/validation_Y_crops.npy', validation_Y_crops)

In [None]:
train_X_crops = np.load('/dev/shm/salt/segmentation_data/train_X_crops.npy')
train_Y_crops = np.load('/dev/shm/salt/segmentation_data/train_Y_crops.npy')
validation_X_crops = np.load('/dev/shm/salt/segmentation_data/validation_X_crops.npy')
validation_Y_crops = np.load('/dev/shm/salt/segmentation_data/validation_Y_crops.npy')

# SageMaker

Next, we train the UNet network for binary segmentation with SageMaker. We start by setting up what is needed to define a training job, which lets us train at scale.

In [None]:
import boto3
import sagemaker
from sagemaker.mxnet import MXNet
from sagemaker import get_execution_role

In [None]:
role = get_execution_role()
sagemaker_session = sagemaker.Session()

We need to upload the data to S3 so the instances launched for the training job can pull the data down.

In [None]:
inputs = sagemaker_session.upload_data(path='/dev/shm/salt/segmentation_data', key_prefix='sagemaker_data')

Finally, we create the MXNet SageMaker model estimator using the **Bring your own script** paradigm. We've defined a script, <tt>segmentation.py</tt>, that runs the training loop for UNet in MXNet Symbolic. To do this, we follow the conventions defined for the SageMaker Python SDK [here](https://github.com/aws/sagemaker-python-sdk).
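
The contents of <tt>segmentation.py</tt> are not shown in this notebook. The following is a rough sketch of the data-loading side of such a script, under two assumptions: that the script uses the function-based interface expected by older MXNet framework versions in the SageMaker Python SDK (a <tt>train</tt> function that receives the hyperparameters and the local directories of the input channels), and that the S3 prefix passed to <tt>fit</tt> arrives as a channel named <tt>training</tt>. The helper <tt>load_split</tt> and the returned summary are hypothetical, and the network definition and training loop are deliberately left out.

```python
import os

import numpy as np


def load_split(data_dir, split):
    """Load the crops saved earlier for one split ('train' or 'validation')."""
    x = np.load(os.path.join(data_dir, '{}_X_crops.npy'.format(split)))
    y = np.load(os.path.join(data_dir, '{}_Y_crops.npy'.format(split)))
    return x, y


def train(hyperparameters, channel_input_dirs, num_gpus=0, **kwargs):
    """Sketch of an entry point: read hyperparameters and load the data.

    Assumption: channel_input_dirs['training'] is the local copy of the
    uploaded S3 prefix. A real script would build and fit the UNet here and
    return the trained model instead of a summary.
    """
    data_dir = channel_input_dirs['training']
    # Hyperparameters may arrive as strings, so convert them explicitly.
    batch_size = int(hyperparameters.get('batch_size', 64))
    epochs = int(hyperparameters.get('epochs', 100))

    train_x, train_y = load_split(data_dir, 'train')
    val_x, val_y = load_split(data_dir, 'validation')

    # Placeholder for the omitted model definition and training loop.
    return {
        'batch_size': batch_size,
        'epochs': epochs,
        'num_train': len(train_x),
        'num_validation': len(val_x),
    }
```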

## Train

In [None]:
#SAGEMAKER NOTEBOOK CODE
sagemaker_net = MXNet("segmentation.py", 
                  role=role, 
                  train_instance_count=2, 
                  train_instance_type="ml.p3.16xlarge",
                  sagemaker_session=sagemaker_session,
                  framework_version="1.2",
                  hyperparameters={
                                 'data_shape': (3, 256, 256),
                                 'batch_size': 64, 
                                 'epochs': 100, 
                                 'learning_rate': 1E-3, 
                                 'num_gpus': 1,
                                  })

sagemaker_net.fit(inputs)

Once the training is complete, we can launch an endpoint server that serves inference with our trained model.

In [None]:
sagemaker_predictor = sagemaker_net.deploy(initial_instance_count=1, instance_type='ml.p2.xlarge')

## Test

In [None]:
test_iter = mx.io.NDArrayIter(data = validation_X_crops, label=validation_Y_crops, batch_size=1, shuffle=False)

We can send test data to the inference endpoint:

In [None]:
batch = next(test_iter)
data = batch.data[0]
label = batch.label[0]
response = sagemaker_predictor.predict(data.asnumpy().tolist())
output = np.array(response[0])

In [None]:
def post_process_mask(label, p=0.5):
    return (np.where(label > p, 1, 0)).astype('uint8')

In [None]:
width = 12
height = 12
plt.figure(figsize=(width, height))
plt.subplot(331)
plt.title('Input')
plt.imshow(np.transpose(data.asnumpy()[0], (1,2,0)).astype(np.uint8))
plt.subplot(332)
plt.title('Prediction')
plt.imshow(post_process_mask(output[0]), cmap=plt.cm.gray)
plt.subplot(333)
plt.title('Mask')
plt.imshow(label[0][0].asnumpy(), cmap=plt.cm.gray)
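
Beyond the visual comparison, agreement between a predicted mask and its label can be quantified. One common choice for segmentation is intersection over union (IoU). The sketch below is an illustration added here, not part of the original notebook, and the function name <tt>mask_iou</tt> is hypothetical.

```python
import numpy as np


def mask_iou(pred_mask, true_mask):
    """Intersection over union of two binary masks of the same shape."""
    pred = np.asarray(pred_mask).astype(bool)
    true = np.asarray(true_mask).astype(bool)
    union = np.logical_or(pred, true).sum()
    if union == 0:
        # Both masks are empty; treating this as perfect agreement
        # is a convention chosen for this sketch.
        return 1.0
    intersection = np.logical_and(pred, true).sum()
    return float(intersection) / float(union)


# One overlapping pixel out of two pixels covered by either mask.
print(mask_iou([[1, 1], [0, 0]], [[1, 0], [0, 0]]))
```

Given two arrays of the same shape, such as a thresholded prediction and its label, <tt>mask_iou</tt> returns a value between 0 and 1, where 1 means the masks coincide exactly.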

Don't forget to delete your endpoint when you're done with it.

In [None]:
sagemaker_net.delete_endpoint()