Pytorch 学习回顾--线性回归
生成数据集
我们可以调用框架中现有的 API 来读取数据。我们将features
和labels
作为 API 的参数传递,并在实例化数据迭代器对象时指定batch_size
。此外,布尔值is_train
表示是否希望数据迭代器对象在每个迭代周期内打乱数据。
def load_array(data_arrays, batch_size, is_train=True):
"""构造一个 PyTorch 数据迭代器。"""
dataset = data.TensorDataset(*data_arrays)
return data.DataLoader(dataset, batch_size, shuffle=is_train)
batch_size = 10
data_iter = load_array((features, labels), batch_size)