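I have a CSV file in which the first column is the label and the 784 columns after it hold a flattened 28*28 image. I use the function below, load(filename), to build a tuple of numpy arrays from it. As a next step I am trying to split this dataset 80%/20% for training and validation, using the loadData() method shown further down. When I run that function to do the split, I get the error "could not broadcast input array from shape (5851,784) into shape (5851)". My question: I simply want to split the data returned by load(filename) into two datasets. Any help?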
import csv
import numpy as np

filename = dir_path + 'train1.csv'

def load(filename):
    # read file into a list of rows
    with open(filename, newline='') as csvfile:
        lines = csv.reader(csvfile, delimiter=',')
        rows = list(lines)
    # create empty numpy arrays of the required size
    data = np.empty((len(rows), len(rows[0]) - 1), dtype=np.float64)
    expected = np.empty((len(rows),), dtype=np.int64)
    # fill array with data from the csv-rows
    for i, row in enumerate(rows):
        data[i, :] = row[1:]   # pixel values
        expected[i] = row[0]   # label
    training_data = data, expected
    return training_data

print(load(filename))
Output:
(array([[ 0., 0., 0., ..., 0., 0., 0.],
       [ 0., 0., 0., ..., 0., 0., 0.],
       [ 0., 0., 0., ..., 0., 0., 0.],
       ...,
       [ 0., 0., 0., ..., 0., 0., 0.],
       [ 0., 0., 0., ..., 0., 0., 0.],
       [ 0., 0., 0., ..., 0., 0., 0.]]), array([1, 1, 1, ..., 1, 1, 1]))
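As an aside, the same (data, expected) tuple could probably be produced with less code by letting NumPy parse the file. The following is only a sketch: `load_with_loadtxt` is a hypothetical name, it assumes the CSV has no header row and contains only numbers, and the in-memory sample uses 3 pixel columns instead of 784 to stay short.

```python
import io
import numpy as np

def load_with_loadtxt(source):
    """Hypothetical alternative to load(): returns (data, expected)."""
    # Parse the whole CSV into one 2-D float array (assumes no header row).
    table = np.loadtxt(source, delimiter=',', ndmin=2)
    data = table[:, 1:]                      # every column after the first: pixels
    expected = table[:, 0].astype(np.int64)  # first column: labels
    return data, expected

# Tiny in-memory stand-in for the real file (3 pixels per row, not 784).
sample = io.StringIO("1,0,0,255\n7,0,128,0\n")
data, expected = load_with_loadtxt(sample)
print(data.shape, expected)
```

`np.loadtxt` also accepts a file path, so `load_with_loadtxt(filename)` should work the same way on the real CSV under the assumptions above.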
Running this function to do the split:
def loadData():
    train_data = load(train_name)
    training_data, validation_data = np.split(train_data, [int(.8 * len(train_data))])
    return train_data

print(loadData())
Result:
could not broadcast input array from shape (5851,784) into shape (5851)
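The error most likely comes from handing the whole (data, expected) tuple to np.split. NumPy tries to turn the tuple into a single array, and a (5851, 784) array cannot be combined with a (5851,) array. In addition, len(train_data) is the length of the tuple (2), not the number of rows. A minimal sketch of one way around this is to slice both arrays at the same row index; `split_80_20` is a hypothetical helper name, and the small zero-filled arrays only stand in for the real data.

```python
import numpy as np

def split_80_20(data, expected, train_fraction=0.8):
    """Split images and labels at the same row index (no shuffling)."""
    cut = int(train_fraction * len(data))  # number of rows, not tuple length
    x_train, x_val = data[:cut], data[cut:]
    y_train, y_val = expected[:cut], expected[cut:]
    return (x_train, y_train), (x_val, y_val)

# Stand-in for the tuple returned by load(filename): 10 rows of 784 pixels.
data = np.zeros((10, 784))
expected = np.arange(10)

training_data, validation_data = split_80_20(data, expected)
print(training_data[0].shape, validation_data[0].shape)
```

Note that this keeps the rows in file order. If the CSV is sorted by label, the validation part would contain only the last labels, so shuffling first is usually advisable.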
Solution:
import csv
import numpy as np
from sklearn.model_selection import train_test_split

train_name = dir_path + 'train8.csv'
test_name = dir_path + 'test8.csv'

def load(filename):
    # read file into a list of rows
    with open(filename, newline='') as csvfile:
        lines = csv.reader(csvfile, delimiter=',')
        rows = list(lines)
    # create empty numpy arrays of the required size
    data = np.empty((len(rows), len(rows[0]) - 1), dtype=np.float64)
    expected = np.empty((len(rows),), dtype=np.int64)
    # fill array with data from the csv-rows
    for i, row in enumerate(rows):
        data[i, :] = row[1:]   # pixel values
        expected[i] = row[0]   # label
    result_data = data, expected
    return result_data

def loadData():
    # load the training file once and unpack images and labels
    train_data, labels = load(train_name)
    test_data = load(test_name)
    # test_size=0.33 holds out 33% for validation; use test_size=0.2 for an 80%/20% split
    x_train, x_test, y_train, y_test = train_test_split(train_data, labels, test_size=0.33)
    training_data = (x_train, y_train)
    validation_data = (x_test, y_test)
    return (training_data, validation_data, test_data)
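This solution returns the data in the same layout as the MNIST dataset: a (training_data, validation_data, test_data) tuple in which each element is an (images, labels) pair.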
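If scikit-learn is not available, a shuffled 80%/20% split can be approximated with NumPy alone. This is only a rough sketch and not a drop-in replacement for train_test_split: `shuffled_split`, its parameters, and the fixed seed are assumptions made for illustration, and the small arrays stand in for the real images and labels.

```python
import numpy as np

def shuffled_split(data, expected, validation_fraction=0.2, seed=0):
    """Shuffle row indices once, then split images and labels together."""
    rng = np.random.default_rng(seed)      # fixed seed for repeatable splits
    order = rng.permutation(len(data))     # random ordering of row indices
    n_val = int(round(validation_fraction * len(data)))
    val_idx, train_idx = order[:n_val], order[n_val:]
    training_data = (data[train_idx], expected[train_idx])
    validation_data = (data[val_idx], expected[val_idx])
    return training_data, validation_data

# Stand-in data: row i holds the values (2*i, 2*i + 1) and has label i.
data = np.arange(20, dtype=np.float64).reshape(10, 2)
expected = np.arange(10)

training_data, validation_data = shuffled_split(data, expected)
print(training_data[0].shape, validation_data[0].shape)
```

Indexing both arrays with the same index array keeps every image paired with its label, which is the property the original np.split attempt could not provide.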