跳转至

Machine Learning

Pytorch 学习回顾--线性回归

生成数据集

我们可以调用框架中现有的 API 来读取数据。我们将featureslabels作为 API 的参数传递,并在实例化数据迭代器对象时指定batch_size。此外,布尔值is_train表示是否希望数据迭代器对象在每个迭代周期内打乱数据。

def load_array(data_arrays, batch_size, is_train=True): 
    """构造一个 PyTorch 数据迭代器。"""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)
batch_size = 10
data_iter = load_array((features, labels), batch_size)