数据增强

数据增强的方式有很多,比如对图像进行几何变换(如翻转、旋转、变形、缩放等)、颜色变换(包括噪声、模糊、颜色变换、檫除、填充等),将有限的数据,进行充分的利用。这里将介绍的仅仅是对图像数据进行任意方向的移动操作(上下左右)来扩充数据。

这里将使用scipy中的shift变换工具(from scipy.ndimage.interpolation import shift)

数据增强

   常用的参数:input输入图像数据为ndarray类型的,

        shift参数代表表示各个维度的偏移量[1,1]表示第一个第二个维度均偏移1,

        cval参数代表偏移后原来位置用什么来填充

from scipy.ndimage.interpolation import shift
def shift_digit(digit_array,dx,dy,new = 0):
    return shift(digit_array.reshape(28,28),[dy,dx],cval = new).reshape(784)
plot_digit(shift_digit(some_digit,5,1,new =100))

数据增强

   一个简单的数据偏移完成,接下来对整个训练集进行扩充

X_train_expanded = [X_train]
y_train_expanded = [y_train]
for dx,dy in ((1,0),(-1,0),(0,1),(0,-1)):
    shifted_image = np.apply_along_axis(shift_digit,axis = 1,arr = X_train,dx = dx,dy = dy)
    X_train_expanded.append(shifted_image)
    y_train_expanded.append(y_train)

X_train_expanded = np.concatenate(X_train_expanded)
y_train_expanded = np.concatenate(y_train_expanded)
X_train_expanded.shape,y_train_expanded.shape

数据增强

   数据增加大了30万之多,有了更多的数据,接下来进行训练、预测,计算精度

knn_clf.fit(X_train_expanded,y_train_expanded)

数据增强

 

 

y_knn_expanded_pred = knn_clf.predict(X_test)
accuracy_score(y_test,y_knn_expanded_pred)

数据增强

 另一种表示方式:

def shift_image(image,dx,dy):
    image = image.reshape((28,28))
    shifted_image = shift(image,[dy,dx],cval = 0,mode = 'constant')
    return shifted_image.reshape([-1])
image = X_train[1000]
shifted_image_down = shift_image(image,0,5)
shifted_image_left = shift_image(image,-5,0)

plt.figure(figsize=(12,3))
plt.subplot(131)
plt.title("Original",fontsize= 14)
plt.imshow(image.reshape(28,28),interpolation='nearest',cmap = 'Greys')
plt.subplot(132)
plt.title("shifted down",fontsize= 14)
plt.imshow(shifted_image_down.reshape(28,28),interpolation='nearest',cmap = 'Greys')
plt.subplot(133)
plt.title("shifted left",fontsize= 14)
plt.imshow(shifted_image_left.reshape(28,28),interpolation='nearest',cmap = 'Greys')
plt.show()

数据增强

X_train_augmented = [image for image in X_train]
y_train_augmented = [label for label in y_train]

for dx,dy in ((1,0),(-1,0),(0,1),(0,-1)):
    for image,label in zip(X_train,y_train):
        X_train_augmented.append(shift_image(image,dx,dy))
        y_train_augmented.append(label)
X_train_augmented = np.array(X_train_augmented)
y_train_augmented = np.array(y_train_augmented)
#打乱顺序
shuffle_idx = np.random.permutation(len(X_train_augmented)) X_train_augmented = X_train_augmented[shuffle_idx] y_train_augmented = y_train_augmented[shuffle_idx]
knn_clf = KNeighborsClassifier(**grid_search.best_params_)
knn_clf.fit(X_train_augmented,y_train_augmented)
y_pred = knn_clf.predict(X_test)
accuracy_score(y_test,y_pred)

数据增强

 此时准确率已达到97%以上

关于knn_clf = KNeighborsClassifier(**grid_search.best_params_)中的**犯傻了很久,**代表着该参数中包含了多个参数,在C++中也会有这种参数表示,
也可参看python中*args与**kwargs的介绍(https://pythontips.com/2013/08/04/args-and-kwargs-in-python-explained/)

当然,scipy.ndimage.interpolation 也包含了其他的数据增强的方法,如旋转、缩放等(参考:https://blog.csdn.net/songchunxiao1991/article/details/88531086)

原文链接: https://www.cnblogs.com/whiteBear/p/12455432.html

欢迎关注

微信关注下方公众号,第一时间获取干货硬货;公众号内回复【pdf】免费获取数百本计算机经典书籍;

也有高质量的技术群,里面有嵌入式、搜广推等BAT大佬

    数据增强

原创文章受到原创版权保护。转载请注明出处:https://www.ccppcoding.com/archives/372369

非原创文章文中已经注明原地址,如有侵权,联系删除

关注公众号【高性能架构探索】,第一时间获取最新文章

转载文章受原作者版权保护。转载请注明原作者出处!

(0)
上一篇 2023年3月3日 上午11:11
下一篇 2023年3月3日 上午11:12

相关推荐