Лучший способ сохранить подготовленную модель в PyTorch?
Я искал альтернативные способы сохранения обученной модели в PyTorch. До сих пор я нашел две альтернативы.
Я столкнулся с этим обсуждением , где подход 2 рекомендуется для подхода 1.
Мой вопрос в том, почему предпочтительнее второй подход? Это только потому, что torch.nn модули имеют эти две функции, и нам рекомендуется их использовать?
Ответы
Ответ 1
Я нашел эту страницу в своем реестре github, я просто вставляю содержимое здесь.
Рекомендуемый подход для сохранения модели
Существует два основных подхода к сериализации и восстановлению модели.
Первый (рекомендуется) сохраняет и загружает только параметры модели:
torch.save(the_model.state_dict(), PATH)
Затем позже:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
Второй сохраняет и загружает всю модель:
torch.save(the_model, PATH)
Затем позже:
the_model = torch.load(PATH)
Однако в этом случае сериализованные данные привязаны к определенным классам
и точная структура каталогов, поэтому она может разрываться по-разному, когда
используется в других проектах или после некоторых серьезных рефакторов.
Ответ 2
Это зависит от того, что вы хотите сделать.
Случай №1: Сохраните модель, чтобы использовать ее самостоятельно для вывода: вы сохраняете модель, вы ее восстанавливаете, а затем вы меняете модель на режим оценки. Это делается потому, что у вас обычно есть BatchNorm
и Dropout
которые по умолчанию находятся в режиме поезда при построении:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
Случай №2: Сохраните модель, чтобы возобновить обучение позже: если вам нужно продолжить обучение модели, которую вы собираетесь сохранить, вам нужно сохранить больше, чем просто модель. Вам также нужно сохранить состояние оптимизатора, эпох, очков и т.д. Вы бы сделали это следующим образом:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
Чтобы возобновить обучение, вы будете делать такие вещи, как: state = torch.load(filepath)
, а затем, чтобы восстановить состояние каждого отдельного объекта, примерно так:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
Поскольку вы возобновляете обучение, НЕ model.eval()
после восстановления состояний при загрузке.
Случай №3: модель, которую будет использовать кто-то другой, не имеющий доступа к вашему коду: в Tensorflow вы можете создать файл .pb
который определяет как архитектуру, так и вес модели. Это очень удобно, особенно при использовании Tensorflow serve
. Аналогичный способ сделать это в Питорхе:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
Этот путь до сих пор не является доказательством пули, и поскольку pytorch все еще претерпевает большие изменения, я бы не рекомендовал его.
Ответ 3
Библиотека Python pickle реализует двоичные протоколы для сериализации и десериализации объекта Python.
Когда вы import torch
(или когда вы используете PyTorch), он import pickle
для вас, и вам не нужно pickle.dump()
вызывать pickle.dump()
и pickle.load()
, которые являются методами для сохранения и загрузки объекта.
Фактически torch.save()
и torch.load()
pickle.load()
для вас функции pickle.dump()
и pickle.load()
.
state_dict
другой упомянутый ответ заслуживает еще несколько замечаний.
Какой state_dict
у нас внутри PyTorch? На самом деле существует два state_dict
.
Модель PyTorch - torch.nn.Module
У torch.nn.Module
есть model.parameters()
для получения изучаемых параметров (w и b). Эти обучаемые параметры, однажды установленные случайным образом, будут обновляться с течением времени по мере нашего изучения. Изучаемые параметры - это первый state_dict
.
Второй state_dict
- это state_dict
состояния оптимизатора. Вы помните, что оптимизатор используется для улучшения наших усваиваемых параметров. Но оптимизатор state_dict
исправлен. Там нечему учиться.
Поскольку объекты state_dict
являются словарями Python, их можно легко сохранять, обновлять, изменять и восстанавливать, добавляя большую модульность моделям и оптимизаторам PyTorch.
Давайте создадим супер простую модель, чтобы объяснить это:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Этот код выведет следующее:
Model state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Обратите внимание, что это минимальная модель. Вы можете попробовать добавить стек последовательных
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Обратите внимание, что только слои с усваиваемыми параметрами (сверточные слои, линейные слои и т.д.) И зарегистрированными буферами (пакетные слои) имеют записи в модели state_dict
.
Непознаваемые вещи принадлежат объекту оптимизатора state_dict
, который содержит информацию о состоянии оптимизатора, а также используемые гиперпараметры.
Остальная часть истории такая же; на этапе вывода (это этап, когда мы используем модель после обучения) для прогнозирования; мы делаем прогноз на основе параметров, которые мы узнали. Поэтому для вывода нам просто нужно сохранить параметры model.state_dict()
.
torch.save(model.state_dict(), filepath)
И использовать позже model.load_state_dict (torch.load(filepath)) model.eval()
Примечание: не забудьте model.eval()
последнюю строку model.eval()
это важно после загрузки модели.
Также не пытайтесь сохранить torch.save(model.parameters(), filepath)
. model.parameters()
это просто объект генератора.
С другой стороны, torch.save(model, filepath)
сохраняет сам объект модели, но имейте в виду, что модель не имеет оптимизатора state_dict
. Посмотрите другой отличный ответ @Jadiel de Armas, чтобы сохранить информацию о состоянии оптимизатора.
Ответ 4
Общепринятым соглашением PyTorch является сохранение моделей с использованием расширения файлов .pt или .pth.
Сохранить/загрузить всю модель Сохранить:
path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)
Нагрузка:
Класс модели должен быть определен где-то
model = torch.load(PATH)
model.eval()