Эффективные способы загрузки модели в PyTorch
Загрузка модели PyTorch: основные методы и рекомендации
Наиболее универсальный и эффективный способ загрузить модель PyTorch - использование функции torch.load с явным указанием map_location. Это позволяет контролировать, на какое устройство будет загружена модель, и избежать типичных ошибок несовместимости.
import torch
import torchvision.models as models
# Предположим, модель сохранена как 'model.pth'
model = models.resnet18(pretrained=False)
# Загрузка на CPU (даже если модель сохранялась на GPU)
checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)
# Загрузка на конкретный GPU (cuda:0)
# checkpoint = torch.load('model.pth', map_location='cuda:0')Python torch load (загрузка модели pytorch)
Файл model.pth должен содержать словарь состояния (state_dict), сохранённый ранее через torch.save(model.state_dict(), ...). После загрузки словарь применяется к экземпляру модели с помощью load_state_dict.
Типичные проблемы:
- Ошибка "RuntimeError: Expected all tensors to be on the same device" - возникает, если часть весов загружена на CPU, а часть на GPU. Решение: всегда использовать map_location.
- Ошибка "ModuleNotFoundError" - если сохранённая модель использует пользовательские классы, которые не импортированы в сценарий загрузки. Решение: импортировать все необходимые классы перед вызовом torch.load.
- ValueError: too many dimensions - несоответствие архитектуры модели при загрузке state_dict. Решение: убедиться, что экземпляр модели имеет ту же структуру, что и при сохранении.
Как загрузить модель целиком (включая архитектуру) вместо state_dict?
Если модель сохранена целиком через torch.save(model, 'full_model.pth'), её можно загрузить напрямую:
model = torch.load('full_model.pth', map_location='cpu')
model.eval() # перевод в режим оценки
Этот вариант удобен для быстрого прототипирования, но создаёт жёсткую зависимость от кода модели. Рекомендуется для финального развёртывания использовать TorchScript или сохранение state_dict.
Как загрузить модель в режиме eval() с отключёнными dropout и batch_norm?
После загрузки всегда следует вызывать model.eval():
model = torch.load('model.pth', map_location='cpu')
model.eval()
# Теперь dropout и batch_norm работают в режиме инференса
Как загрузить только часть слоёв (например, для fine-tuning)?
Можно загрузить state_dict с параметром strict=False, чтобы игнорировать отсутствующие ключи:
pretrained_dict = torch.load('pretrained.pth', map_location='cpu')
model.load_state_dict(pretrained_dict, strict=False)
# Пропущенные слои останутся инициализированными случайно
Проблема: ключи словаря не совпадают из-за префикса (например, у модели, обёрнутой в DataParallel).
Решение: убрать префикс "module." вручную:
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
name = k[7:] if k.startswith('module.') else k # удаление 'module.'
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
Как загрузить модель из чекпоинта, содержащего оптимизатор и эпоху?
checkpoint = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
Это стандартный подход для продолжения обучения.
Как загрузить модель, сохранённую в формате TorchScript?
model = torch.jit.load('model_scripted.pt', map_location='cpu')
output = model(input_tensor)
TorchScript позволяет запускать модель без кода Python - удобно для продакшена.
Расширенные примеры загрузки модели PyTorch
Пример 1: Загрузка части весов с переименованием ключей (например, для переноса обучения с другой архитектуры)
import torch
import torch.nn as nn
# Определим простую модель
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 5)
# Создадим другой класс с другим именем слоя
class AnotherModel(nn.Module):
def __init__(self):
super().__init__()
self.fcA = nn.Linear(10, 20)
self.fcB = nn.Linear(20, 5)
simple = SimpleModel()
another = AnotherModel()
# Сохраняем state_dict SimpleModel
torch.save(simple.state_dict(), 'simple.pth')
# Загружаем и вручную заменяем имена ключей
checkpoint = torch.load('simple.pth', map_location='cpu')
mapping = {'fc1.weight': 'fcA.weight', 'fc1.bias': 'fcA.bias',
'fc2.weight': 'fcB.weight', 'fc2.bias': 'fcB.bias'}
adapted = {mapping.get(k, k): v for k, v in checkpoint.items()}
another.load_state_dict(adapted, strict=False)
print("Веса перенесены, пропущенные ключи проигнорированы.")
Веса перенесены, пропущенные ключи проигнорированы.
Пример 2: Загрузка модели с обработкой разных версий PyTorch (ключ _metadata)
checkpoint = torch.load('old_model.pth', map_location='cpu')
# Если сохранён словарь с метаданными (версия, производная информация)
if '_metadata' in checkpoint:
del checkpoint['_metadata'] # удаляем для совместимости
model.load_state_dict(checkpoint, strict=False)
Пример 3: Загрузка модели с динамическим созданием класса (если модель была сохранена целиком)
import torch
import sys
# Сохранённая модель может требовать определения класса
class CustomModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3)
# Перед torch.load класс должен быть доступен
model = torch.load('custom.pth', map_location='cpu')
print(type(model)) #
Пример 4: Загрузка модели и немедленное тестирование на случайных данных
import torch
import torchvision.models as models
model = models.resnet50(pretrained=False)
checkpoint = torch.load('resnet50_weights.pth', map_location='cpu')
model.load_state_dict(checkpoint)
model.eval()
# Генерация случайного тензора и прогон
input_tensor = torch.randn(1, 3, 224, 224)
with torch.no_grad():
output = model(input_tensor)
print("Форма выхода:", output.shape)
Форма выхода: torch.Size([1, 1000])
Пример 5: Загрузка нескольких чекпоинтов с циклическим сохранением (например, best и last)
import os
checkpoints = ['best_model.pth', 'last_model.pth']
for path in checkpoints:
if os.path.exists(path):
state = torch.load(path, map_location='cpu')
# Фильтр только state_dict
if isinstance(state, dict) and 'model_state_dict' in state:
print(f"{path}: эпоха {state['epoch']}")
else:
print(f"{path}: содержит только state_dict")
best_model.pth: эпоха 42 last_model.pth: эпоха 50
Пример 6: Загрузка модели с кастомным map_location, например, на 'cuda:1' с падением на CPU
try:
model = torch.load('model.pth', map_location='cuda:1')
except RuntimeError:
print("GPU cuda:1 недоступен, загружаем на CPU")
model = torch.load('model.pth', map_location='cpu')
Пример 7: Использование lambda в map_location для детерминированного устройства
checkpoint = torch.load('model.pth',
map_location=lambda storage, loc: storage.cuda(0) if loc == 'cuda:0' else storage)