Stochastic Weight Averaging with Enchanter¶
PyTorch v1.6 から Stochastic Weight Averaging (SWA)を行うクラスが導入されました。 このチュートリアルでは、SWAを行うためのRunnerの定義について紹介します。
データセットはMNISTを使います。データセットのダウンロードとMNIST用のCNNを実装しまうす。
from comet_ml import Experiment
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ExponentialLR
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torchvision import transforms
from torchvision.datasets import MNIST
import enchanter.addons as addons
from enchanter.tasks import ClassificationRunner
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 32, 3),
addons.Mish(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3),
addons.Mish(),
nn.MaxPool2d(2)
)
self.fc = nn.Sequential(
nn.Linear(64*5*5, 512),
addons.Mish(),
nn.Linear(512, 10)
)
def forward(self, x):
out = self.conv(x)
out = out.view(-1, 64*5*5)
out = self.fc(out)
return out
train_ds = MNIST(
"./data",
train=True,
transform=transforms.Compose([
transforms.ToTensor()
])
)
test_ds = MNIST(
"./data",
train=False,
transform=transforms.Compose([
transforms.ToTensor()
])
)
train_loader = DataLoader(train_ds, batch_size=32)
test_loader = DataLoader(test_ds, batch_size=32)
データセットとモデルの準備が終わりました。次にSWALRに対応したRunnerを定義しましょう。 簡単のために ClassificationRunner を使いますが、BaseRunner を継承した方法も流れは同じです。
class SWALRRunner(ClassificationRunner):
def __init__(self, *args, **kwargs):
super(SWALRRunner, self).__init__(*args, **kwargs)
self.swa_model = AveragedModel(self.model)
self.swa_scheduler = SWALR(self.optimizer, swa_lr=0.05)
self.swa_start = 5
def update_scheduler(self, epoch: int) -> None:
if epoch > self.swa_start:
self.swa_model.update_parameters(self.model)
self.swa_scheduler.step()
else:
super(SWALRRunner, self).update_scheduler(epoch)
def train_end(self, outputs):
update_bn(self.loaders["train"], self.swa_model)
return super(SWALRRunner, self).train_end(outputs)
最後にRunnerを定義して実行します。
model = Model()
optimizer = optim.Adam(model.parameters())
runner = SWALRRunner(
model, optimizer, nn.CrossEntropyLoss(),
experiment=Experiment(),
scheduler=[
ExponentialLR(optimizer, gamma=0.9)
]
)
runner.add_loader("train", train_loader)
runner.add_loader("test", test_loader)
runner.train_config(epochs=10)
runner.run()