Image data augmentation on-the-fly by adding new classes on top of `transforms` in PyTorch and torchvision.
Normally, we use `from torchvision import transforms` for transformation, but some specific transformations (especially for histology image augmentation) are missing. Thus, we add 4 new transform classes on the basis of the `torchvision.transforms` pyfile, in a file we named `myTransforms.py`. You can call and use them in the same form as `torchvision.transforms`. Alternatively, you can refer to `dataAug_myTransforms.py`.
Also, you can check the actual effect of `myTransforms` for data augmentation :)
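Since the classes in `myTransforms.py` are used in the same form as `torchvision.transforms`, the general shape of such a class may help when reading or extending them. The sketch below is a minimal illustration under our own assumptions, not code from `myTransforms.py`: `RandomFlipSketch` and `ComposeSketch` are hypothetical names, and a nested list stands in for an image so that only the standard library is needed. Parameters go to `__init__`, the image goes to `__call__`, and `__repr__` is what `print(preprocess)` displays.

```python
import random


class RandomFlipSketch:
    """Hypothetical transform in the torchvision style: parameters go to
    __init__, the image goes to __call__."""

    def __init__(self, p=0.5):
        self.p = p  # probability of flipping each row left-to-right

    def __call__(self, img):
        # img is a list of rows, standing in for an image
        if random.random() < self.p:
            return [row[::-1] for row in img]
        return img

    def __repr__(self):
        # print(transform) falls back to __repr__ when no __str__ is defined
        return f"{self.__class__.__name__}(p={self.p})"


class ComposeSketch:
    """Minimal stand-in for torchvision.transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:  # apply each transform in order
            img = t(img)
        return img


pipeline = ComposeSketch([RandomFlipSketch(p=1.0)])
print(RandomFlipSketch(p=1.0))            # RandomFlipSketch(p=1.0)
print(pipeline([[1, 2, 3], [4, 5, 6]]))   # [[3, 2, 1], [6, 5, 4]]
```

`torchvision.transforms.Compose` also just applies its transforms in order, so any class with this shape can be placed in a `Compose` list.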
## myTransforms
### HEDJitter
Randomly perturb the HED color space values of an RGB pathological image [1].
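**Args**

**Example**
```python
from PIL import Image
import matplotlib.pyplot as plt

import myTransforms

imagename = '../data/10-05074_353_49_8178.png'
img = Image.open(imagename)  # read the image
preprocess = myTransforms.HEDJitter(theta=0.05)
print(preprocess)
HEPerimg = preprocess(img)  # HED-perturbed image

# show the original (left) and the augmented image (right)
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.imshow(HEPerimg)
plt.show()
```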
**References**
[1]. Tellez, D., Balkenhol, M., Otte-Höller, I., van de Loo, R., Vogels, R., Bult, P., ... & Litjens, G. (2018). Whole-slide mitosis detection in H&E breast histology using PHH3 as a reference to train distilled stain-invariant convolutional networks. IEEE transactions on medical imaging, 37(9), 2126-2136.
[2]. Ruifrok, A. C., & Johnston, D. A. (2001). Quantification of histochemical staining by color deconvolution. Analytical and quantitative cytology and histology, 23(4), 291-299.
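As a rough illustration of the idea behind `HEDJitter`, the sketch below follows the scheme described in [1] as we understand it: the RGB image is converted to optical density, separated into hematoxylin, eosin and DAB channels by color deconvolution [2], each channel is scaled by `alpha ~ U(1 - theta, 1 + theta)` and shifted by `beta ~ U(-theta, theta)`, and the result is converted back to RGB. The stain matrix holds the commonly used default vectors (the values scikit-image uses for `rgb2hed`), which is an assumption; `hed_jitter` is a hypothetical name, it works on a NumPy float array in [0, 1] rather than a PIL image, and it is not the implementation in `myTransforms.py`.

```python
import numpy as np

# Stain optical-density vectors (rows: hematoxylin, eosin, DAB) after
# Ruifrok & Johnston [2]; assumed default values, as used by scikit-image.
RGB_FROM_HED = np.array([[0.65, 0.70, 0.29],
                         [0.07, 0.99, 0.11],
                         [0.27, 0.57, 0.78]])
HED_FROM_RGB = np.linalg.inv(RGB_FROM_HED)


def hed_jitter(rgb, theta=0.05, rng=None):
    """Hypothetical HED jitter for an HxWx3 float image with values in [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    od = -np.log(np.clip(rgb, 1e-6, 1.0))               # RGB -> optical density
    stains = od.reshape(-1, 3) @ HED_FROM_RGB           # color deconvolution -> H, E, D
    alpha = rng.uniform(1 - theta, 1 + theta, size=3)   # per-stain scale
    beta = rng.uniform(-theta, theta, size=3)           # per-stain shift
    stains = stains * alpha + beta
    od_new = stains @ RGB_FROM_HED                      # H, E, D -> optical density
    out = np.exp(-od_new).reshape(rgb.shape)            # optical density -> RGB
    return np.clip(out, 0.0, 1.0)


# Usage on a PIL image: hed_jitter(np.asarray(img.convert("RGB")) / 255.0)
```

With `theta=0` the scale is 1 and the shift is 0, so the round trip through HED space returns the input unchanged; larger `theta` widens the range of simulated stain variation.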
### RandomElastic
Random elastic transformation of an image using OpenCV (CV2), controlled by the `alpha` and `sigma` parameters.

WARNING: This transform is CPU-intensive and can make preprocessing considerably slower.
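**Args**
- mask (PIL Image): optional argument of the `__call__` function; set None if not assigned.

**Example**
```python
from PIL import Image
import matplotlib.pyplot as plt

import myTransforms

imagename = '../data/10-05074_353_49_8178.png'
img = Image.open(imagename)  # read the image
preprocess = myTransforms.RandomElastic(alpha=2, sigma=0.06)
print(preprocess)
elasticimg = preprocess(img, mask=None)  # no segmentation mask here

# show the original (left) and the augmented image (right)
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.imshow(elasticimg)
plt.show()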
```
![Elastic](https://github.com/gatsby2016/Augmentation-PyTorch-Transforms/blob/master/data/elasticimg.gif)
**References**
- [affine and elastic transform](https://blog.csdn.net/maliang_1993/article/details/82020596)
- [cv2.warpAffine](https://blog.csdn.net/qq_27261889/article/details/80720359)
- [scipy.ndimage.map_coordinates](https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html#scipy.ndimage.map_coordinates)
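The elastic deformation can be sketched without OpenCV or SciPy: draw two random displacement fields, smooth them with a Gaussian of standard deviation `sigma`, scale them by `alpha`, and resample the image at the displaced coordinates. This is a NumPy-only sketch under our own assumptions, not the `RandomElastic` implementation: `alpha` and `sigma` are plain pixel units here (the example values above, `alpha=2, sigma=0.06`, are presumably on a different scale), sampling is nearest-neighbour with clamped borders instead of the interpolation `scipy.ndimage.map_coordinates` would provide, and all names are hypothetical.

```python
import numpy as np


def gaussian_smooth(field, sigma):
    """Separable Gaussian smoothing of a 2-D array with reflected borders."""
    radius = max(1, int(3 * sigma + 0.5))
    x = np.arange(-radius, radius + 1, dtype=float)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    kernel /= kernel.sum()
    padded = np.pad(field, radius, mode="reflect")
    # 'valid' convolution along each axis removes the padding again
    rows = np.apply_along_axis(np.convolve, 1, padded, kernel, mode="valid")
    return np.apply_along_axis(np.convolve, 0, rows, kernel, mode="valid")


def elastic_transform(img, alpha, sigma, mask=None, rng=None):
    """Hypothetical elastic deformation of an HxW or HxWxC array.
    An optional mask receives exactly the same displacement."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape[:2]
    dx = gaussian_smooth(rng.uniform(-1, 1, (h, w)), sigma) * alpha
    dy = gaussian_smooth(rng.uniform(-1, 1, (h, w)), sigma) * alpha
    yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    # nearest-neighbour sampling keeps mask labels intact; clamp at the borders
    src_y = np.clip(np.rint(yy + dy), 0, h - 1).astype(int)
    src_x = np.clip(np.rint(xx + dx), 0, w - 1).astype(int)
    out = img[src_y, src_x]
    if mask is None:
        return out
    return out, mask[src_y, src_x]
```

Reusing the same `src_y`/`src_x` for image and mask is what keeps a segmentation ground truth aligned with the deformed image.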
### RandomAffineCV2
Random affine transformation of an image using OpenCV (CV2), controlled by the `alpha` parameter.
It differs from `torchvision.transforms.RandomAffine`, which is implemented with `PIL.Image` methods: here, the area of the output image that falls outside the transformed input can be filled with `BORDER_REFLECT`, whereas the original `RandomAffine` can only fill it with a specified value.
**Args**
- alpha (float): alpha value for affine transformation
- mask (PIL Image): for processing the ground truth of a segmentation task; if not assigned, set None.
**Example**
```python
from PIL import Image
import matplotlib.pyplot as plt

import myTransforms

imagename = '../data/10-05074_353_49_8178.png'
img = Image.open(imagename)  # read the image
preprocess = myTransforms.RandomAffineCV2(alpha=0.1)  # alpha in [0, 0.15]
print(preprocess)
affinecvimg = preprocess(img)

# show the original (left) and the augmented image (right)
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.imshow(affinecvimg)
plt.show()
```
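To illustrate the border-handling difference described above, the sketch below builds a random affine map by jittering three reference points (the role `cv2.getAffineTransform` plays) and fills the area outside the source image by reflection, mimicking `cv2.BORDER_REFLECT` (`fedcba|abcdefgh|hgfedcb`). It is a NumPy-only sketch with hypothetical names, not the `RandomAffineCV2` code: treating `alpha` as the maximum point shift relative to the image size is our assumption, and sampling is nearest-neighbour.

```python
import numpy as np


def reflect_index(idx, n):
    """Fold integer indices into [0, n) like cv2.BORDER_REFLECT (fedcba|abcdefgh|hgfedcb)."""
    m = np.mod(idx, 2 * n)
    return np.where(m >= n, 2 * n - 1 - m, m)


def affine_from_points(src, dst):
    """2x3 matrix M with M @ [x, y, 1] == dst_i for three point pairs."""
    a = np.hstack([src, np.ones((3, 1))])   # rows [x, y, 1]
    return np.linalg.solve(a, dst).T


def random_affine_reflect(img, alpha, rng=None):
    """Hypothetical random affine warp of an HxW(xC) array with reflected borders.
    alpha: maximum shift of each reference point, as a fraction of the image size."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape[:2]
    size = min(h, w)
    cx, cy, d = w / 2, h / 2, size / 3
    pts = np.array([[cx - d, cy - d], [cx + d, cy - d], [cx - d, cy + d]])
    moved = pts + rng.uniform(-alpha * size, alpha * size, size=pts.shape)
    # Warping samples the input for every output pixel, so solve output -> input
    inv = affine_from_points(moved, pts)
    yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    coords = np.stack([xx.ravel(), yy.ravel(), np.ones(h * w)])   # 3 x (H*W)
    src_x, src_y = inv @ coords
    src_x = reflect_index(np.rint(src_x).astype(int), w).reshape(h, w)
    src_y = reflect_index(np.rint(src_y).astype(int), h).reshape(h, w)
    return img[src_y, src_x]   # nearest-neighbour sampling
```

Because out-of-range coordinates are folded back into the image, every output pixel comes from real tissue rather than a constant fill color.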
### RandomGaussBlur
Random Gaussian blurring of an image, controlled by the `radius` parameter.

**Args**

**Example**
```python
from PIL import Image
import matplotlib.pyplot as plt

import myTransforms

imagename = '../data/10-05074_353_49_8178.png'
img = Image.open(imagename)  # read the image
preprocess = myTransforms.RandomGaussBlur(radius=[0.5, 1.5])
print(preprocess)
blurimg = preprocess(img)

# show the original (left) and the augmented image (right)
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.imshow(blurimg)
plt.show()
```
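The behaviour of `RandomGaussBlur(radius=[0.5, 1.5])` can be sketched as drawing a radius uniformly from the given range on every call and blurring with it. In this hypothetical NumPy-only version the radius is used directly as the Gaussian standard deviation, which is an assumption about how `radius` is interpreted; the class name and the `last_radius` attribute exist only for illustration.

```python
import numpy as np


def gaussian_blur(img, sigma):
    """Separable Gaussian blur of an HxW or HxWxC array with reflected borders."""
    radius = max(1, int(3 * sigma + 0.5))
    x = np.arange(-radius, radius + 1, dtype=float)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    kernel /= kernel.sum()  # weights sum to 1, so flat regions stay flat
    pad = [(radius, radius), (radius, radius)] + [(0, 0)] * (img.ndim - 2)
    out = np.pad(img.astype(float), pad, mode="reflect")
    for axis in (0, 1):  # blur rows, then columns; channels are left alone
        out = np.apply_along_axis(np.convolve, axis, out, kernel, mode="valid")
    return out


class RandomGaussBlurSketch:
    """Hypothetical stand-in: draw a radius uniformly from [low, high] per call."""

    def __init__(self, radius=(0.5, 1.5), rng=None):
        self.low, self.high = radius
        self.rng = np.random.default_rng() if rng is None else rng
        self.last_radius = None  # kept only for inspection

    def __call__(self, img):
        self.last_radius = self.rng.uniform(self.low, self.high)
        return gaussian_blur(img, self.last_radius)

    def __repr__(self):
        return f"{self.__class__.__name__}(radius=[{self.low}, {self.high}])"
```

Drawing a new radius inside `__call__`, rather than once in `__init__`, is what makes the blur strength vary from sample to sample during training.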
### AutoRandomRotation
A modification of `torchvision.transforms.RandomRotation` that automatically selects a random angle from [0, 90, 180, 270] for rotating the image.

**Example**
```python
from PIL import Image
import matplotlib.pyplot as plt

import myTransforms

imagename = '../data/10-05074_353_49_8178.png'
img = Image.open(imagename)  # read the image
preprocess = myTransforms.AutoRandomRotation()
print(preprocess)
rotateimg = preprocess(img)

# show the original (left) and the augmented image (right)
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.imshow(rotateimg)
plt.show()
```
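Rotating by a multiple of 90 degrees needs no interpolation, so it can be written as a count of quarter turns. The sketch below is a hypothetical NumPy stand-in for `AutoRandomRotation`, not its actual code: it picks an angle from [0, 90, 180, 270] and applies `np.rot90` to the height and width axes, leaving any channel axis untouched.

```python
import random

import numpy as np


class AutoRandomRotationSketch:
    """Hypothetical stand-in: rotate by an angle drawn from [0, 90, 180, 270]."""

    angles = (0, 90, 180, 270)

    def __call__(self, img):
        angle = random.choice(self.angles)
        k = angle // 90  # number of quarter turns
        # rotate the height/width plane only; a channel axis (if any) is untouched
        return np.rot90(img, k, axes=(0, 1))

    def __repr__(self):
        return f"{self.__class__.__name__}(angles={list(self.angles)})"
```

Because every angle is a multiple of 90 degrees, pixel values are only moved, never interpolated, which preserves fine histological texture.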