【发布时间】:2021-09-22 10:26:54
【问题描述】:
我正在从事一个与阿尔茨海默病诊断相关的机器学习项目。由于类不平衡,需要对数据进行过采样。因此,使用了 SMOTE。
下面给出了包含 SMOTE 的代码块。遇到此代码块时,Google Colab RAM 会崩溃。
# Oversample the minority classes with SMOTE. fit_resample takes a 2-D
# (samples x features) matrix, so every image is flattened to a single row
# for the call. No intermediate array names are introduced here, so no extra
# references to the image data are kept alive.
sm = SMOTE(random_state=42)
train_data, train_labels = sm.fit_resample(
    train_data.reshape((-1, IMG_SIZE * IMG_SIZE * 3)),
    train_labels,
)

# Put the resampled rows back into (samples, height, width, channels) layout.
train_data = train_data.reshape((-1, IMG_SIZE, IMG_SIZE, 3))

print(train_data.shape, train_labels.shape)
下面给出了这个代码片段(snippet)之前的代码。
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import matplotlib.pyplot as plt
import os
from distutils.dir_util import copy_tree, remove_tree
from PIL import Image
from random import randint
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split
from sklearn.metrics import matthews_corrcoef as MCC
from sklearn.metrics import balanced_accuracy_score as BAS
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow_addons as tfa
from keras.utils.vis_utils import plot_model
from tensorflow.keras import Sequential, Input
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import Conv2D, Flatten
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.preprocessing.image import ImageDataGenerator as IDG
from tensorflow.keras.layers import SeparableConv2D, BatchNormalization, MaxPool2D
# Locations of the source dataset (under /content/drive, presumably a mounted
# Google Drive) and of the local working copy built below.
base_dir = '/content/drive/MyDrive/Alzheimers_Diag_GCUThesis_2021/Alzheimer_sDataset/'
root_dir = "./"
test_dir = base_dir + "test/"
train_dir = base_dir + "train/"
work_dir = root_dir + "dataset2/"

# Remove any working directory left over from an earlier run so that os.mkdir
# below does not fail and stale files are not mixed into the new copy.
if os.path.exists(work_dir):
    remove_tree(work_dir)

os.mkdir(work_dir)

# Merge both the train and the test split into the single working directory.
copy_tree(train_dir, work_dir)
copy_tree(test_dir, work_dir)
print("Working Directory Contents:", os.listdir(work_dir))
# NOTE(review): the directory populated earlier in this script is "./dataset2/",
# while the generator below reads from './dataset' — confirm this is intended.
WORK_DIR = './dataset'

CLASSES = [
    'NonDemented',
    'VeryMildDemented',
    'MildDemented',
    'ModerateDemented',
]

# Image geometry: square IMG_SIZE x IMG_SIZE inputs.
IMG_SIZE = 176
IMAGE_SIZE = [176, 176]
DIM = (IMG_SIZE, IMG_SIZE)

# Augmentation settings handed to the ImageDataGenerator.
ZOOM = [0.99, 1.01]
BRIGHT_RANGE = [0.8, 1.2]
HORZ_FLIP = True
FILL_MODE = "constant"
DATA_FORMAT = "channels_last"

# Generator that rescales pixel values by 1/255 and applies the augmentation
# settings defined above.
work_dr = IDG(
    rescale=1.0 / 255,
    brightness_range=BRIGHT_RANGE,
    zoom_range=ZOOM,
    data_format=DATA_FORMAT,
    fill_mode=FILL_MODE,
    horizontal_flip=HORZ_FLIP,
)

# One very large batch is requested, presumably so that a single call yields
# the whole dataset as in-memory arrays — assumes the directory holds at most
# 6500 images; TODO confirm.
train_data_gen = work_dr.flow_from_directory(
    directory=WORK_DIR,
    target_size=DIM,
    batch_size=6500,
    shuffle=False,
)

# Pull that batch out of the iterator as (images, labels).
train_data, train_labels = train_data_gen.next()

# Report the dimensions of the loaded dataset.
print(train_data.shape, train_labels.shape)
【问题讨论】:
-
请看如何发 minimal reproducible example(强调 minimal);删除与问题无关的代码(包括不相关的导入)。看看为什么 a wall of code isn't helpful。
标签: python machine-learning smote