
Generator does not go through all the files

Qubix · asked 6 years ago

    I am training a model with data from multiple .csv files. As far as I can tell my code reads the files, but the model still trains on only a single one of them. The relevant part of my code is:

    def get_data(datasets_path):
        ''' 
        Returns the dataframes.
        '''
        full_path = datasets_path + "*.csv"
        for data_fname in glob.glob(full_path):
            df = pd.read_csv(data_fname)
            processed_df = __preprocessor(df)
            scaler = MinMaxScaler()
            transformed_df = scaler.fit_transform(processed_df)
            return transformed_df
    
    
    def batch_generator(X, batch_size=16, shuffle=False):
        '''
        Return a random sample from X.
        '''
        count = 0
        while True:
            if shuffle:
                idx = np.random.randint(0, X.shape[0], batch_size)
                data = X[idx]
            else:
                indices = list(n for n in range(X.shape[0]))
                data = X[indices[count*batch_size : (count+1)*batch_size]]
                count +=1
            yield (data, data)
    

    data = get_data(path_to_datasets)
    x_train, x_test = train_test_split(data, test_size=0.2, random_state=42, shuffle=False)
    
    x_train = np.expand_dims(x_train, axis=1)
    x_test = np.expand_dims(x_test, axis=1)
    
    train_gen = batch_generator(x_train, batch_size=32)
    valid_gen = batch_generator(x_test, batch_size=32)
    

    Then I define a simple model and train it with:

    model.fit_generator(
        generator=train_gen,
        epochs=1,
        steps_per_epoch=x_train.shape[0] // 32,
        validation_data=valid_gen,
        validation_steps=x_test.shape[0] // 32)
    

    The problem is that it only seems to draw batches from a single .csv file instead of going through all of them, and I don't understand why.

    1 Answer

    Mitiku · answered 6 years ago

    The problem is the return statement inside the for loop. After processing a single file, get_data returns and the remaining files are never read. Changing return to yield turns get_data into a generator that produces one transformed array per file:

    def get_data(datasets_path):
        ''' 
        Returns the dataframes.
        '''
        full_path = datasets_path + "*.csv"
        for data_fname in glob.glob(full_path):
            df = pd.read_csv(data_fname)
            processed_df = __preprocessor(df)
            scaler = MinMaxScaler()
            transformed_df = scaler.fit_transform(processed_df)
            yield transformed_df
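
    One caveat (an addition to the original answer): with yield, get_data now returns a generator object rather than a single array, so passing its result straight into train_test_split as in the question would fail. The yielded per-file arrays need to be stacked into one matrix first. A minimal sketch of that, using a hypothetical stub in place of get_data so it runs standalone:

    ```python
    import numpy as np

    def get_data_stub():
        # Hypothetical stand-in for the generator version of get_data:
        # yields one preprocessed, scaled array per .csv file.
        for n_rows in (3, 5):
            yield np.zeros((n_rows, 4))

    # Stack every file's rows into a single 2-D array before splitting
    # into train/test sets, e.g. with train_test_split.
    data = np.vstack(list(get_data_stub()))
    print(data.shape)  # (8, 4)
    ```

    Note that this fits a fresh MinMaxScaler per file; if the files should share a common scale, it may be preferable to fit one scaler on the stacked data instead.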