question
现在,我想训练一个U-net来做多类预测。这是我的代码:
import os
import sys
import random
import warnings
import scipy.misc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images, imsave
from skimage.transform import resize
from skimage.morphology import label
import cv2
from keras.models import Model, load_model
from keras.layers import Input
from keras.layers.core import Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K
from keras import optimizers
from keras.utils import multi_gpu_model
import tensorflow as tf
# Set some parameters
IMG_WIDTH = 256
IMG_HEIGHT = 256
IMG_CHANNELS = 3
mask_dim = 256*256
TRAIN_IM = './train_im/'
TRAIN_MASK = './train_mask/'
num_training = len(os.listdir(TRAIN_IM))
num_test = len(os.listdir(TEST_PATH))
# Get and resize train images
X_train = np.zeros((num_training, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
Y_train = np.zeros((num_training, mask_dim, 4), dtype=np.bool)
print('Getting and resizing train images and masks ... ')
sys.stdout.flush()
#load training images
for count, filename in tqdm(enumerate(os.listdir(TRAIN_IM)), total=num_training):
img = imread(os.path.join(TRAIN_IM, filename))[:,:,:IMG_CHANNELS]
img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
X_train[count] = img
name, ext = os.path.splitext(filename)
mask_name = name + '_mask' + ext
mask = cv2.imread(os.path.join(TRAIN_MASK, mask_name))[:,:,:1]
mask = np.squeeze(mask)
Y_train[count] = mask
# Build U-Net model
''' ...import VGG16 layers to use pretrained weights...'''
width = 32
c1 = Conv2D(width, (3, 3), activation='elu', padding='same') (s)
c1 = Conv2D(width, (3, 3), activation='elu', padding='same') (c1)
p1 = MaxPooling2D((2, 2)) (c1)
c2 = Conv2D(width*2, (3, 3), activation='elu', padding='same') (p1)
c2 = Conv2D(width*2, (3, 3), activation='elu', padding='same') (c2)
p2 = MaxPooling2D((2, 2)) (c2)
c3 = Conv2D(width*4, (3, 3), activation='elu', padding='same') (p2)
c3 = Conv2D(width*4, (3, 3), activation='elu', padding='same') (c3)
p3 = MaxPooling2D((2, 2)) (c3)
c4 = Conv2D(width*8, (3, 3), activation='elu', padding='same') (p3)
c4 = Conv2D(width*8, (3, 3), activation='elu', padding='same') (c4)
p4 = MaxPooling2D(pool_size=(2, 2)) (c4)
c5 = Conv2D(width*16, (3, 3), activation='elu', padding='same') (p4)
c5 = Conv2D(width*16, (3, 3), activation='elu', padding='same') (c5)
u6 = Conv2DTranspose(width*8, (2, 2), strides=(2, 2), padding='same') (c5)
u6 = concatenate([u6, c4, ll['block4_conv3']])
c6 = Conv2D(width*8, (3, 3), activation='elu', padding='same') (u6)
c6 = Conv2D(width*8, (3, 3), activation='elu', padding='same') (c6)
u7 = Conv2DTranspose(width*4, (2, 2), strides=(2, 2), padding='same') (c6)
u7 = concatenate([u7, c3, ll['block3_conv3']])
c7 = Conv2D(width*4, (3, 3), activation='elu', padding='same') (u7)
c7 = Conv2D(width*4, (3, 3), activation='elu', padding='same') (c7)
u8 = Conv2DTranspose(width*2, (2, 2), strides=(2, 2), padding='same') (c7)
u8 = concatenate([u8, c2, ll['block2_conv2']])
c8 = Conv2D(width*2, (3, 3), activation='elu', padding='same') (u8)
c8 = Conv2D(width*2, (3, 3), activation='elu', padding='same') (c8)
u9 = Conv2DTranspose(width, (2, 2), strides=(2, 2), padding='same') (c8)
u9 = concatenate([u9, c1, ll['block1_conv2']], axis=3)
c9 = Conv2D(width, (3, 3), activation='elu', padding='same') (u9)
c9 = Conv2D(width, (3, 3), activation='elu', padding='same') (c9)
outputs = Conv2DTranspose(1, (1, 1), activation='sigmoid') (c9)
model = Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer='adam', loss = bce_dice, metrics = ['accuracy'])
model.summary()
earlystopper = EarlyStopping(patience=20, verbose=1)
checkpointer = ModelCheckpoint('model-dsbowl2018-1.h5', verbose=1, save_best_only=True)
results = model.fit(X_train, Y_train, validation_split=0, batch_size=1, epochs=100,
callbacks=[earlystopper, checkpointer])
model.load_weights('model-dsbowl2018-1.h5')
因为它给了我以下错误: