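PM2.5 Concentration Estimation
Objectives:
Use a deep neural network (DNN) to solve a regression problem
Learn the basics of training a DNN
Get familiar with PyTorch, one of today's mainstream deep learning frameworks
For this assignment we provide a training set and a validation set (both labeled) and a test set (unlabeled). Use the training and validation sets to find the best model, then produce predictions on the test set.
Finally, submit your predictions to Kaggle (https://www.kaggle.com/competitions/ml-dl-2022-hw1/overview); the system automatically returns your test-set error (MSE).
In short, the goal is to start from the code we provide, apply what you learned in class, and modify the code to get the test-set error as low as possible.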
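The score reported by Kaggle is the mean squared error. As a minimal sketch of how that number is computed (the function name and the sample values below are illustrative assumptions, not part of the assignment code):

```python
def mean_squared_error(y_true, y_pred):
    """Average of the squared differences between targets and predictions."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("inputs must not be empty")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Made-up PM2.5 values, for illustration only.
targets = [30.0, 42.0, 55.0]
predictions = [28.0, 45.0, 55.0]
print(mean_squared_error(targets, predictions))  # (4 + 9 + 0) / 3
```

Because the errors are squared, a few badly mispredicted samples can dominate the score.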
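Preparation
Downloading the data
Here we download the data needed for this assignment from Google Drive: train.csv, val.csv, and test.csv.
Once the download finishes, open the file browser in the left sidebar to find the three CSV files; double-click a file to preview its structure. Each row is one observation and each column is one feature:
station: station ID
date: date the data was collected
lat, lon: station latitude and longitude
AOD: aerosol optical depth
ET: evaporation
BLH: boundary layer height
TEM: temperature
NDVI: vegetation index
SP: sea-level pressure
RH: relative humidity
DEM: ground elevation
NTL: nighttime light index
LUC: land cover type
PRE: precipitation
WS, WD: wind speed and wind direction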
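One rough way to judge which of the columns listed above are worth feeding to the model is to rank them by their Pearson correlation with PM2.5. The sketch below is only an illustration: the helper names `pearson_r` and `rank_features` are hypothetical, the toy numbers are made up, and Pearson correlation only captures linear relationships, so a low score does not prove a feature is useless.

```python
import math


def pearson_r(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        return 0.0  # a constant column carries no linear signal
    return cov / math.sqrt(var_x * var_y)


def rank_features(columns, target):
    """Sort feature names by |r| with the target, strongest first."""
    scores = {name: pearson_r(values, target) for name, values in columns.items()}
    return sorted(scores.items(), key=lambda item: abs(item[1]), reverse=True)


# Toy columns with made-up numbers; real values would come from train.csv.
toy_columns = {
    "AOD": [0.2, 0.4, 0.6, 0.8],
    "WS": [3.0, 2.5, 2.0, 1.0],
    "DEM": [10.0, 10.0, 10.0, 10.0],
}
toy_pm25 = [20.0, 40.0, 60.0, 80.0]

for name, r in rank_features(toy_columns, toy_pm25):
    print(f"{name}: r = {r:+.3f}")
```

With the real data, each column of train.csv would take the place of the toy lists, and the ranking could inform which names to keep in the model's input.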
[ ]:
train_path = 'train.csv'
val_path = 'val.csv'
test_path = 'test.csv'
# gdown deprecated the --id option in 4.3.1; a file ID can be passed directly
!gdown 15BfpT7ieOq_RaMKz9fZQF1gpQLEu83be --output 'train.csv'
!gdown 1w4v26_bRzrIUTKQSTtPl7Klu1XZtZITX --output 'val.csv'
!gdown 1KRFj-T2fS_K8-4tOX_pN7ugQNz8bseVn --output 'test.csv'
Downloading...
From: https://drive.google.com/uc?id=15BfpT7ieOq_RaMKz9fZQF1gpQLEu83be
To: /content/train.csv
100% 1.72M/1.72M [00:00<00:00, 115MB/s]
Downloading...
From: https://drive.google.com/uc?id=1w4v26_bRzrIUTKQSTtPl7Klu1XZtZITX
To: /content/val.csv
100% 520k/520k [00:00<00:00, 97.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=1KRFj-T2fS_K8-4tOX_pN7ugQNz8bseVn
To: /content/test.csv
100% 657k/657k [00:00<00:00, 88.5MB/s]
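Setting up the environment
Next, we import the libraries the code depends on.
[ ]:
# If you bring in new libraries, import them here
import pandas  # provides DataFrame, used to handle the CSV data
import torch  # PyTorch, a mainstream deep learning framework
from torch.utils import data  # data module from torch.utils, used to build PyTorch data structures such as DataLoader
from torch import nn  # nn module, provides basic neural network building blocks such as the fully connected layer nn.Linear()
import numpy as np  # NumPy, imported as np; an efficient library for array and matrix operations
from tqdm import tqdm  # tqdm displays a progress bar for loops
import os  # os handles file-system operations
from sklearn.preprocessing import StandardScaler  # utility for standardizing features
import csv  # reads and writes CSV files
normalize = StandardScaler()  # instantiate StandardScaler
# To handle large batches of data, mainstream deep learning frameworks represent data as matrices (tensors).
# CPUs are comparatively slow at large matrix operations, while GPUs excel at them, so deep learning models are usually run on a GPU.
# The line below checks whether CUDA (GPU acceleration) is available on the current machine.
# If it is, the model runs on the GPU; otherwise we fall back to the CPU (which is slower).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")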
Running the experiment
Building the data
[ ]:
# This code creates the basic data structure PyTorch needs, data.Dataset
class PM2_5Dataset(data.Dataset):
    def __init__(self, csv_path, mode="train"):
        super(PM2_5Dataset, self).__init__()
        self.csv_content = pandas.read_csv(csv_path)
        # TODO: are all of these features useful? Edit this list to change the model's input features
        self.used_column = ["AOD", "ET", "BLH", "TEM", "NDVI", "SP", "RH", "DEM", "NTL", "PRE", "WS", "WD"]
        self.target_column = "PM2.5"
        self.mode = mode
        self.dim = len(self.used_column)
        if mode in ["train", "val"]:
            self.input_data, self.target_data = self.process_csv_with_gt()
        else:
            self.input_data = self.process_csv_without_gt()

    def process_csv_with_gt(self):
        input_data = [list(self.csv_content[i]) for i in self.used_column]
        input_data = np.array(input_data)
        input_data = input_data.T
        # Fit the scaler on the training set only; the validation set reuses the
        # training statistics, so the training dataset must be built first
        if self.mode == "train":
            input_data = normalize.fit_transform(input_data)
        else:
            input_data = normalize.transform(input_data)
        target_data = np.array(self.csv_content[self.target_column])
        return input_data, target_data

    def process_csv_without_gt(self):
        input_data = [list(self.csv_content[i]) for i in self.used_column]
        input_data = np.array(input_data)
        input_data = input_data.T
        # Reuse the statistics fitted on the training set instead of refitting on test data
        input_data = normalize.transform(input_data)
        return input_data

    def __getitem__(self, index):
        if self.mode in ["train", "val"]:
            data, gt = self.input_data[index, :], self.target_data[index]
            data = torch.from_numpy(data)
            gt = torch.tensor(gt)
            return data, gt
        else:
            data = self.input_data[index, :]
            return torch.from_numpy(data)

    def __len__(self):
        return self.input_data.shape[0]


def prepare_dataset(csv_path, batch_size, mode):
    dataset = PM2_5Dataset(csv_path, mode)
    # In PyTorch, a dataset is wrapped in a DataLoader
    if mode in ["train", "val"]:
        # Shuffle only the training data so each epoch sees the batches in a new order
        dataloader = data.DataLoader(dataset, batch_size, shuffle=(mode == "train"), num_workers=0, pin_memory=True)
    else:
        dataloader = data.DataLoader(dataset, 1, num_workers=0, pin_memory=True, shuffle=False)
    return dataloader
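Standardization statistics are normally estimated from the training set and then reused unchanged for validation and test data, so that every split passes through the same transform. The following is a minimal pure-Python sketch of that idea for a single column; the helper names and the sample values are assumptions for illustration, and in the notebook itself StandardScaler plays this role.

```python
import statistics


def fit_standardizer(train_column):
    """Return (mean, std) computed from the training values only."""
    mean = statistics.fmean(train_column)
    std = statistics.pstdev(train_column)  # population std (ddof=0), as StandardScaler uses
    if std == 0:
        std = 1.0  # avoid dividing by zero for a constant column
    return mean, std


def apply_standardizer(column, mean, std):
    """Scale any split with statistics that were fitted elsewhere."""
    return [(value - mean) / std for value in column]


# Made-up temperature values, for illustration only.
train_tem = [10.0, 20.0, 30.0]
val_tem = [20.0, 40.0]

mean, std = fit_standardizer(train_tem)
print(apply_standardizer(train_tem, mean, std))
print(apply_standardizer(val_tem, mean, std))  # scaled with the *training* mean and std
```

Refitting on each split separately would map the same raw value to different inputs in training and at test time, which the model has no way to compensate for.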
Building the model
[ ]:
# A basic neural network with three fully connected layers and ReLU activations
class MyNeuralNet(nn.Module):
    # TODO: add regularization such as dropout or batchnorm, add or remove fully connected layers,
    # change the number of neurons, change how layers are connected, etc.
    def __init__(self, input_dim):
        super(MyNeuralNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, input):
        predict = self.net(input)
        return predict
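When changing layer counts or widths, it can help to know how many trainable parameters a configuration has. The sketch below counts weights and biases for a stack of fully connected layers; the function name is a hypothetical helper, and the example widths assume the 12 input features and the 64/32/1 layout used above.

```python
def count_mlp_parameters(layer_sizes):
    """Weights plus biases for fully connected layers of the given widths."""
    total = 0
    for fan_in, fan_out in zip(layer_sizes, layer_sizes[1:]):
        total += fan_in * fan_out + fan_out  # weight matrix + bias vector
    return total


# 12 inputs -> 64 -> 32 -> 1 output
print(count_mlp_parameters([12, 64, 32, 1]))  # 832 + 2080 + 33 = 2945
```

Comparing this number with the size of the training set gives a rough sense of how much capacity the model has relative to the data.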
Training, validation, and testing
This part provides the training, validation, and test code, which you may modify. For example, you can change the loss function in the training code, change the inference strategy in the test code (boosting, bagging), or merge the validation set into the training set.
[ ]:
def train(train_dataset, val_dataset, config, model):
    """Training loop."""
    n_epochs = config['n_epochs']  # number of epochs
    optimizer = getattr(torch.optim, config['optimizer'])(model.parameters(), **config['optim_hparas'])  # optimizer
    pbar = tqdm(total=n_epochs, desc="Train Mode", unit="epoch")
    loss = nn.MSELoss()  # loss function: MSE (mean squared error)
    loss_log = {"train_loss": [], "val_loss": []}
    min_mse = 1e16
    for epoch in range(n_epochs):
        # val() switches the model to eval mode, so re-enable training mode at the
        # start of every epoch (this affects layers such as dropout and batchnorm)
        model.train()
        pbar.update()
        epoch_mse = 0.0
        for x, y in train_dataset:
            optimizer.zero_grad()  # clear the gradients accumulated from the previous step
            x, y = x.to(device).float(), y.to(device).float()
            y = y.unsqueeze(1)
            pred = model(x)
            mse_loss = loss(pred, y)  # compute the loss
            mse_loss.backward()  # backpropagate the loss
            optimizer.step()  # update the parameters
            loss_log["train_loss"].append(mse_loss.detach().cpu().item())
            epoch_mse += mse_loss.detach().cpu().item()
        pbar.set_postfix(epoch=epoch + 1, mse=epoch_mse / len(train_dataset))
        val_mse = val(val_dataset, model)
        if val_mse < min_mse:
            min_mse = val_mse
            torch.save(model.state_dict(), config['save_path'])
            print(f"best model saved! val mse:{val_mse},epoch:{epoch + 1}")
        loss_log["val_loss"].append(val_mse)
    return min_mse, loss_log


def val(val_dataset, model):
    """Validation loop."""
    model.eval()  # evaluation mode (dropout off, batchnorm uses running statistics)
    total_loss = 0
    loss = nn.MSELoss()
    for x, y in val_dataset:
        x, y = x.to(device).float(), y.to(device).float()
        y = y.unsqueeze(1)
        with torch.no_grad():  # disable gradient tracking during evaluation
            pred = model(x)  # predictions
            mse_loss = loss(pred, y)  # compute the loss
        total_loss += mse_loss.detach().cpu().item()
    total_loss = total_loss / len(val_dataset)
    return total_loss


def test(test_set, model):
    """Test-time inference."""
    model.eval()  # evaluation mode
    preds = []
    for x in test_set:
        x = x.to(device).float()
        with torch.no_grad():
            pred = model(x)
        preds.append(pred.detach().cpu())
    return preds
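One common modification to a training loop like this is early stopping: training ends once the validation MSE has not improved for a set number of epochs. The sketch below is a framework-independent illustration; the class name `EarlyStopping`, the `patience` and `min_delta` parameters, and the sample validation history are assumptions, not part of the provided code.

```python
class EarlyStopping:
    """Signal a stop when validation loss fails to improve for `patience` epochs."""

    def __init__(self, patience=20, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_mse):
        """Record one epoch's validation MSE; return True when training should stop."""
        if val_mse < self.best - self.min_delta:
            self.best = val_mse
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation history, for illustration only.
stopper = EarlyStopping(patience=3)
history = [687.7, 424.2, 430.0, 426.1, 425.0, 410.3]
for epoch, val_mse in enumerate(history, start=1):
    if stopper.step(val_mse):
        print(f"stopping at epoch {epoch}, best val mse {stopper.best}")
        break
```

Inside a training loop, `stopper.step(val_mse)` would be called once per epoch right after validation, with a `break` when it returns True. A patience that is too small can stop before a later improvement, as the last value in the toy history shows.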
Processing the predictions
[ ]:
# This part does not need to be modified
def process_preds_to_csv(preds):
    with open("MyPred.csv", "w", newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['id', 'PM2.5'])
        for i, p in enumerate(preds):
            p = p.cpu().item()
            writer.writerow([i, p])
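If several models are trained, for example on different resamples of the training data as in bagging, their per-sample predictions can be averaged before being written out. The sketch below shows only the averaging step on plain Python floats; the function name and the numbers are illustrative assumptions, and tensor predictions would first need converting (for instance with `.item()`).

```python
def average_predictions(per_model_predictions):
    """Element-wise mean of several equally long prediction lists."""
    if not per_model_predictions:
        raise ValueError("need at least one model's predictions")
    n = len(per_model_predictions[0])
    if any(len(p) != n for p in per_model_predictions):
        raise ValueError("all prediction lists must have the same length")
    count = len(per_model_predictions)
    return [sum(values) / count for values in zip(*per_model_predictions)]


# Made-up predictions from three hypothetical models on the same three samples.
model_a = [30.0, 50.0, 41.0]
model_b = [34.0, 46.0, 43.0]
model_c = [32.0, 48.0, 42.0]
print(average_predictions([model_a, model_b, model_c]))  # [32.0, 48.0, 42.0]
```

Averaging tends to reduce the variance of individual models, which is the motivation behind bagging; whether it helps here would need to be checked on the validation set.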
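Main program
[ ]:
model = None
os.makedirs('models', exist_ok=True)
# Model hyperparameters
config = {
    'n_epochs': 200,  # number of epochs
    'batch_size': 128,  # mini-batch size
    'optimizer': 'SGD',  # optimization algorithm (an optimizer in torch.optim)
    'optim_hparas': {  # optimizer hyperparameters (depend on the chosen algorithm)
        'lr': 0.001,  # SGD learning rate
        'momentum': 0.9  # SGD momentum
    },
    'save_path': 'models/model.pth'  # where the best model is saved
}
# Build the datasets
train_dataset = prepare_dataset(r"train.csv", config["batch_size"], mode="train")
val_dataset = prepare_dataset(r"val.csv", config["batch_size"], mode="val")
test_dataset = prepare_dataset(r"test.csv", config["batch_size"], mode="test")
# Instantiate the model
model = MyNeuralNet(train_dataset.dataset.dim).to(device)
# Train and validate
min_mse, loss_log = train(train_dataset, val_dataset, config, model)
# Load the checkpoint with the lowest validation MSE rather than using the last-epoch weights
model.load_state_dict(torch.load(config['save_path'], map_location=device))
# Run the model on the test set
preds = test(test_dataset, model)
# Save the test predictions to a CSV file
process_preds_to_csv(preds)
# When this finishes, download the predictions (MyPred.csv) from the left sidebar and submit them to Kaggle, which reports your test-set score (MSE)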
best model saved! val mse:687.7400957743326,epoch:1
best model saved! val mse:424.2332191467285,epoch:2
best model saved! val mse:422.3498725891113,epoch:22
best model saved! val mse:417.58240636189777,epoch:23
best model saved! val mse:415.9144922892253,epoch:24
best model saved! val mse:410.31239827473956,epoch:26
best model saved! val mse:410.2173728942871,epoch:42
best model saved! val mse:399.73748270670575,epoch:43
best model saved! val mse:396.26459248860675,epoch:46
best model saved! val mse:388.1736437479655,epoch:47
Train Mode: 84%|████████▎ | 167/200 [00:40<00:07, 4.31epoch/s, epoch=167, mse=203]
Train Mode: 84%|████████▍ | 168/200 [00:40<00:07, 4.31epoch/s, epoch=167, mse=203]
Train Mode: 84%|████████▍ | 168/200 [00:41<00:07, 4.31epoch/s, epoch=168, mse=200]
Train Mode: 84%|████████▍ | 169/200 [00:41<00:07, 4.35epoch/s, epoch=168, mse=200]
Train Mode: 84%|████████▍ | 169/200 [00:41<00:07, 4.35epoch/s, epoch=169, mse=202]
Train Mode: 85%|████████▌ | 170/200 [00:41<00:06, 4.36epoch/s, epoch=169, mse=202]
Train Mode: 85%|████████▌ | 170/200 [00:41<00:06, 4.36epoch/s, epoch=170, mse=205]
Train Mode: 86%|████████▌ | 171/200 [00:41<00:06, 4.36epoch/s, epoch=170, mse=205]
Train Mode: 86%|████████▌ | 171/200 [00:41<00:06, 4.36epoch/s, epoch=171, mse=208]
Train Mode: 86%|████████▌ | 172/200 [00:41<00:06, 4.39epoch/s, epoch=171, mse=208]
Train Mode: 86%|████████▌ | 172/200 [00:41<00:06, 4.39epoch/s, epoch=172, mse=218]
Train Mode: 86%|████████▋ | 173/200 [00:42<00:06, 4.23epoch/s, epoch=172, mse=218]
Train Mode: 86%|████████▋ | 173/200 [00:42<00:06, 4.23epoch/s, epoch=173, mse=219]
Train Mode: 87%|████████▋ | 174/200 [00:42<00:06, 4.18epoch/s, epoch=173, mse=219]
Train Mode: 87%|████████▋ | 174/200 [00:42<00:06, 4.18epoch/s, epoch=174, mse=207]
Train Mode: 88%|████████▊ | 175/200 [00:42<00:05, 4.19epoch/s, epoch=174, mse=207]
Train Mode: 88%|████████▊ | 175/200 [00:42<00:05, 4.19epoch/s, epoch=175, mse=205]
Train Mode: 88%|████████▊ | 176/200 [00:42<00:05, 4.29epoch/s, epoch=175, mse=205]
Train Mode: 88%|████████▊ | 176/200 [00:42<00:05, 4.29epoch/s, epoch=176, mse=203]
Train Mode: 88%|████████▊ | 177/200 [00:42<00:05, 4.26epoch/s, epoch=176, mse=203]
Train Mode: 88%|████████▊ | 177/200 [00:43<00:05, 4.26epoch/s, epoch=177, mse=206]
Train Mode: 89%|████████▉ | 178/200 [00:43<00:05, 4.29epoch/s, epoch=177, mse=206]
Train Mode: 89%|████████▉ | 178/200 [00:43<00:05, 4.29epoch/s, epoch=178, mse=200]
Train Mode: 90%|████████▉ | 179/200 [00:43<00:04, 4.36epoch/s, epoch=178, mse=200]
Train Mode: 90%|████████▉ | 179/200 [00:43<00:04, 4.36epoch/s, epoch=179, mse=197]
Train Mode: 90%|█████████ | 180/200 [00:43<00:04, 4.33epoch/s, epoch=179, mse=197]
Train Mode: 90%|█████████ | 180/200 [00:43<00:04, 4.33epoch/s, epoch=180, mse=195]
Train Mode: 90%|█████████ | 181/200 [00:43<00:04, 4.10epoch/s, epoch=180, mse=195]
Train Mode: 90%|█████████ | 181/200 [00:44<00:04, 4.10epoch/s, epoch=181, mse=199]
Train Mode: 91%|█████████ | 182/200 [00:44<00:04, 4.17epoch/s, epoch=181, mse=199]
Train Mode: 91%|█████████ | 182/200 [00:44<00:04, 4.17epoch/s, epoch=182, mse=201]
Train Mode: 92%|█████████▏| 183/200 [00:44<00:03, 4.30epoch/s, epoch=182, mse=201]
Train Mode: 92%|█████████▏| 183/200 [00:44<00:03, 4.30epoch/s, epoch=183, mse=200]
Train Mode: 92%|█████████▏| 184/200 [00:44<00:03, 4.29epoch/s, epoch=183, mse=200]
Train Mode: 92%|█████████▏| 184/200 [00:44<00:03, 4.29epoch/s, epoch=184, mse=202]
Train Mode: 92%|█████████▎| 185/200 [00:44<00:03, 4.30epoch/s, epoch=184, mse=202]
Train Mode: 92%|█████████▎| 185/200 [00:44<00:03, 4.30epoch/s, epoch=185, mse=199]
Train Mode: 93%|█████████▎| 186/200 [00:45<00:03, 4.35epoch/s, epoch=185, mse=199]
Train Mode: 93%|█████████▎| 186/200 [00:45<00:03, 4.35epoch/s, epoch=186, mse=198]
Train Mode: 94%|█████████▎| 187/200 [00:45<00:03, 4.33epoch/s, epoch=186, mse=198]
Train Mode: 94%|█████████▎| 187/200 [00:45<00:03, 4.33epoch/s, epoch=187, mse=201]
Train Mode: 94%|█████████▍| 188/200 [00:45<00:02, 4.26epoch/s, epoch=187, mse=201]
Train Mode: 94%|█████████▍| 188/200 [00:45<00:02, 4.26epoch/s, epoch=188, mse=195]
Train Mode: 94%|█████████▍| 189/200 [00:45<00:02, 4.23epoch/s, epoch=188, mse=195]
Train Mode: 94%|█████████▍| 189/200 [00:45<00:02, 4.23epoch/s, epoch=189, mse=194]
Train Mode: 95%|█████████▌| 190/200 [00:45<00:02, 4.27epoch/s, epoch=189, mse=194]
Train Mode: 95%|█████████▌| 190/200 [00:46<00:02, 4.27epoch/s, epoch=190, mse=192]
Train Mode: 96%|█████████▌| 191/200 [00:46<00:02, 4.31epoch/s, epoch=190, mse=192]
Train Mode: 96%|█████████▌| 191/200 [00:46<00:02, 4.31epoch/s, epoch=191, mse=191]
Train Mode: 96%|█████████▌| 192/200 [00:46<00:01, 4.36epoch/s, epoch=191, mse=191]
Train Mode: 96%|█████████▌| 192/200 [00:46<00:01, 4.36epoch/s, epoch=192, mse=195]
Train Mode: 96%|█████████▋| 193/200 [00:46<00:01, 4.36epoch/s, epoch=192, mse=195]
Train Mode: 96%|█████████▋| 193/200 [00:46<00:01, 4.36epoch/s, epoch=193, mse=196]
Train Mode: 97%|█████████▋| 194/200 [00:46<00:01, 4.19epoch/s, epoch=193, mse=196]
Train Mode: 97%|█████████▋| 194/200 [00:47<00:01, 4.19epoch/s, epoch=194, mse=193]
Train Mode: 98%|█████████▊| 195/200 [00:47<00:01, 4.15epoch/s, epoch=194, mse=193]
Train Mode: 98%|█████████▊| 195/200 [00:47<00:01, 4.15epoch/s, epoch=195, mse=192]
Train Mode: 98%|█████████▊| 196/200 [00:47<00:01, 3.97epoch/s, epoch=195, mse=192]
Train Mode: 98%|█████████▊| 196/200 [00:47<00:01, 3.97epoch/s, epoch=196, mse=195]
Train Mode: 98%|█████████▊| 197/200 [00:47<00:00, 3.77epoch/s, epoch=196, mse=195]
Train Mode: 98%|█████████▊| 197/200 [00:47<00:00, 3.77epoch/s, epoch=197, mse=206]
Train Mode: 99%|█████████▉| 198/200 [00:47<00:00, 3.90epoch/s, epoch=197, mse=206]
Train Mode: 99%|█████████▉| 198/200 [00:48<00:00, 3.90epoch/s, epoch=198, mse=200]
Train Mode: 100%|█████████▉| 199/200 [00:48<00:00, 4.01epoch/s, epoch=198, mse=200]
Train Mode: 100%|█████████▉| 199/200 [00:48<00:00, 4.01epoch/s, epoch=199, mse=198]
Train Mode: 100%|██████████| 200/200 [00:48<00:00, 4.09epoch/s, epoch=199, mse=198]
Train Mode: 100%|██████████| 200/200 [00:48<00:00, 4.11epoch/s, epoch=200, mse=188]
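Visualization
Next, we visualize how the loss (MSE) changes during training and validation,
and how our final predictions compare with the ground-truth values.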
[ ]:
import torch  # needed by plot_pred
import matplotlib.pyplot as plt  # used to draw the curves
from matplotlib.pyplot import figure

def plot_learning_curve(loss_record, title=''):
    ''' Plot the training and validation loss curves '''
    total_steps = len(loss_record['train_loss'])
    n_val = len(loss_record['val_loss'])
    x_1 = range(total_steps)
    # Validation loss may be logged less often than training loss, so subsample the x axis.
    # The trailing slice keeps x_2 the same length as val_loss when the counts do not divide evenly.
    step = max(1, total_steps // max(1, n_val))
    x_2 = x_1[::step][:n_val]
    figure(figsize=(6, 4))
    plt.plot(x_1, loss_record['train_loss'], c='tab:red', label='train')
    plt.plot(x_2, loss_record['val_loss'], c='tab:cyan', label='val')
    plt.ylim(0.0, 1000.)
    plt.xlabel('Training steps')
    plt.ylabel('MSE loss')
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()

def plot_pred(dv_set, model, device, lim=35., title="", preds=None, targets=None):
    ''' Scatter-plot the predictions against the ground-truth values '''
    if preds is None or targets is None:
        # No precomputed predictions were passed in, so run the model over dv_set.
        model.eval()
        preds, targets = [], []
        for x, y in dv_set:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x.float())
                preds.append(pred.detach().cpu())
                targets.append(y.detach().cpu())
        preds = torch.cat(preds, dim=0).numpy()
        targets = torch.cat(targets, dim=0).numpy()

    figure(figsize=(5, 5))
    plt.scatter(targets, preds, c='r', alpha=0.5)
    # The blue diagonal marks a perfect prediction (predicted == ground truth).
    plt.plot([-0.2, lim], [-0.2, lim], c='b')
    plt.xlim(-0.2, lim)
    plt.ylim(-0.2, lim)
    plt.xlabel('ground truth value')
    plt.ylabel('predicted value')
    plt.title(title)
    plt.show()

plot_learning_curve(loss_log, "DNN")
plot_pred(train_dataset, model, device, 35., 'Ground Truth of Train v.s. Prediction')
plot_pred(val_dataset, model, device, 35., 'Ground Truth of Val v.s. Prediction')
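What you need to do
Basic requirements
Understand what each part of this code does.
The code contains one small error. Can you find it? (Hint: look at the feature standardization in the data construction section.)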
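As background for the hint about standardization, here is a minimal sketch of one common convention: compute the per-feature mean and standard deviation on the training set only, then reuse those same statistics for the validation and test sets. This is not necessarily the error the hint refers to. The helper names `fit_standardizer` and `apply_standardizer` are hypothetical, and the random arrays merely stand in for the features loaded from the csv files.

```python
import numpy as np

def fit_standardizer(train_x, eps=1e-8):
    """Compute per-feature mean and std from the training features only."""
    mean = train_x.mean(axis=0)
    std = train_x.std(axis=0) + eps  # eps guards against division by zero for constant features
    return mean, std

def apply_standardizer(x, mean, std):
    """Standardize any split with statistics that were fitted on the training split."""
    return (x - mean) / std

# Random stand-in data (assumption): 3 features, 100 training rows, 20 validation rows.
rng = np.random.default_rng(0)
train_x = rng.normal(loc=5.0, scale=2.0, size=(100, 3))
val_x = rng.normal(loc=5.0, scale=2.0, size=(20, 3))

mean, std = fit_standardizer(train_x)
train_z = apply_standardizer(train_x, mean, std)
val_z = apply_standardizer(val_x, mean, std)  # reuses the training statistics
```

With this convention the validation and test features are scaled exactly as the model saw them during training, so the three splits stay comparable.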
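Intermediate requirements
Feature selection: we use all the features provided in the csv files to predict PM2.5. Are they all useful? Would using only a subset of them work better? Why?
Hyperparameter tuning: the main program has a config variable that stores many of our hyperparameters, such as the learning rate and batch size. Does adjusting them improve the model's accuracy? For the optimizer, we use SGD (stochastic gradient descent). Are there other optimizers that could be used, and how should their hyperparameters (such as momentum) be set?
L2 regularization: keeping the magnitude of the model's weights from growing too large during training often improves generalization. L2 regularization adds a penalty on the model's weights to the loss.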
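The optimizer and L2 questions above can be explored with a sketch like the following. It assumes a small stand-in model with 16 input features (a hypothetical number, not taken from the assignment code), and the learning rate, momentum and weight decay values are placeholders to tune rather than recommendations. In PyTorch, the `weight_decay` argument of `torch.optim.SGD` and `torch.optim.Adam` applies an L2 penalty on the parameters.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; in the assignment you would pass your own model instead.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))

# SGD with momentum; weight_decay adds an L2 penalty on the weights.
sgd = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)

# Adam is a common alternative that adapts the step size per parameter.
adam = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# One training step on random stand-in data, to show where the optimizer is used.
criterion = nn.MSELoss()
x = torch.randn(8, 16)
y = torch.randn(8, 1)

optimizer = sgd  # swap in `adam` to compare
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
```

Comparing validation MSE across a few settings of `lr`, `momentum` and `weight_decay` is one simple way to judge whether a change actually helps.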
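Advanced requirements
Model modification: do deeper networks and more neurons give better results? Are there ways to counter overfitting (such as dropout or batchnorm)?
Location and time: do geographic location (lon, lat) and time information (date) improve the accuracy of the PM2.5 estimate? If so, how should they be used?
Other: use your own ideas to reduce the test error as much as possible, and explain why they work.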
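As a starting point for the model and time questions above, the sketch below shows one possible network with BatchNorm and Dropout, plus a cyclical encoding of the day of year. Everything here is an assumption for illustration: the class name `RegularizedDNN`, the helper `encode_day_of_year`, the layer sizes, the dropout rate and the input dimension of 18 are hypothetical, and the date is assumed to be already parsed into a `datetime.date`.

```python
import math
from datetime import date

import torch
import torch.nn as nn

def encode_day_of_year(d):
    """Map a date to (sin, cos) of its position in the year, so Dec 31 and Jan 1 end up close."""
    angle = 2.0 * math.pi * (d.timetuple().tm_yday - 1) / 365.25
    return math.sin(angle), math.cos(angle)

class RegularizedDNN(nn.Module):
    """A small fully connected regressor with BatchNorm and Dropout (hypothetical sizes)."""
    def __init__(self, input_dim, hidden_dim=64, p_drop=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        # Squeeze the last dimension so the output has shape (batch,)
        return self.net(x).squeeze(1)

model = RegularizedDNN(input_dim=18)
model.eval()  # eval mode disables dropout and uses BatchNorm running statistics
with torch.no_grad():
    out = model(torch.randn(4, 18))

sin_doy, cos_doy = encode_day_of_year(date(2022, 1, 1))
```

The two values returned by `encode_day_of_year` could be appended to the feature vector in place of the raw date. Remember to call `model.train()` during training and `model.eval()` during validation and testing, since both Dropout and BatchNorm behave differently in the two modes.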