代码之家  ›  专栏  ›  技术社区  ›  Davide Fiocco

如何从pyspark数据帧创建PyTorch数据集(使用Databricks)?

  •  0
  • Davide Fiocco  · 技术社区  · 3 年前

    Dataset 使用pyspark数据帧作为原始数据(不确定这是正确的方法……)。

    为了预处理文本,我使用 transformers 图书馆和图书馆 tokenizing_UDF

    这个 数据集 DataLoader 训练ML模型。

    我现在拥有的是:

    import pandas as pd # ideally I'd like to get rid of pandas here
    import torch
    from torch.utils.data.dataset import Dataset
    from transformers import BertTokenizer
    from pyspark.sql import types as T
    from pyspark.sql import functions as F
    
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    text = ["This is a test.", "This is not a test."]*100
    label = [1, 0]*100
    
    df = sqlContext.createDataFrame(zip(text, label), schema=['text', 'label'])
    tokenizing_UDF = udf(lambda t: tokenizer.encode(t),  T.ArrayType(T.LongType())) 
    df = df.withColumn("tokenized", tokenizing_UDF(F.col("text"))) # not sure this is the right way
    df = df.toPandas() # ugly
    
    class TokenizedDataset(Dataset):
        """needs refactoring..."""
        def __init__(self, df):
            self.data = df
            
        def __getitem__(self, index):
            text = self.data.loc[index].tokenized
            text = torch.LongTensor(text)
            label = self.data.loc[index].label
            return (text, label)
    
        def __len__(self):
            count = len(self.data)
            return count
    
    dataset = TokenizedDataset(df) # slow...
    

    我现在调用 .toPandas() 所以我的 TokenizedDataset 可以处理数据帧。

    这是明智的做法吗? 如果是,我应该如何修改 标记化数据集 直接处理pyspark数据帧的代码? https://github.com/uber/petastorm 相反呢?

    0 回复  |  直到 3 年前