Эффективные способы загрузки модели в PyTorch

Раздел: Машинное обучение -> Работа с PyTorch

Загрузка модели PyTorch: основные методы и рекомендации

Наиболее универсальный и эффективный способ загрузить модель PyTorch - использование функции torch.load с явным указанием map_location. Это позволяет контролировать, на какое устройство будет загружена модель, и избежать типичных ошибок несовместимости.

import torch
import torchvision.models as models

# Предположим, модель сохранена как 'model.pth'
model = models.resnet18(pretrained=False)

# Загрузка на CPU (даже если модель сохранялась на GPU)
checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)

# Загрузка на конкретный GPU (cuda:0)
# checkpoint = torch.load('model.pth', map_location='cuda:0')

Python torch load (загрузка модели pytorch)

Файл model.pth должен содержать словарь состояния (state_dict), сохранённый ранее через torch.save(model.state_dict(), ...). После загрузки словарь применяется к экземпляру модели с помощью load_state_dict.

Типичные проблемы:

  • Ошибка "RuntimeError: Expected all tensors to be on the same device" - возникает, если часть весов загружена на CPU, а часть на GPU. Решение: всегда использовать map_location.
  • Ошибка "ModuleNotFoundError" - если сохранённая модель использует пользовательские классы, которые не импортированы в сценарий загрузки. Решение: импортировать все необходимые классы перед вызовом torch.load.
  • ValueError: too many dimensions - несоответствие архитектуры модели при загрузке state_dict. Решение: убедиться, что экземпляр модели имеет ту же структуру, что и при сохранении.

Как загрузить модель целиком (включая архитектуру) вместо state_dict?

Если модель сохранена целиком через torch.save(model, 'full_model.pth'), её можно загрузить напрямую:

model = torch.load('full_model.pth', map_location='cpu')
model.eval()  # перевод в режим оценки

Этот вариант удобен для быстрого прототипирования, но создаёт жёсткую зависимость от кода модели. Рекомендуется для финального развёртывания использовать TorchScript или сохранение state_dict.

Как загрузить модель в режиме eval() с отключёнными dropout и batch_norm?

После загрузки всегда следует вызывать model.eval():

model = torch.load('model.pth', map_location='cpu')
model.eval()
# Теперь dropout и batch_norm работают в режиме инференса

Как загрузить только часть слоёв (например, для fine-tuning)?

Можно загрузить state_dict с параметром strict=False, чтобы игнорировать отсутствующие ключи:

pretrained_dict = torch.load('pretrained.pth', map_location='cpu')
model.load_state_dict(pretrained_dict, strict=False)
# Пропущенные слои останутся инициализированными случайно

Проблема: ключи словаря не совпадают из-за префикса (например, у модели, обёрнутой в DataParallel).

Решение: убрать префикс "module." вручную:

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = k[7:] if k.startswith('module.') else k  # удаление 'module.'
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)

Как загрузить модель из чекпоинта, содержащего оптимизатор и эпоху?

checkpoint = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

Это стандартный подход для продолжения обучения.

Как загрузить модель, сохранённую в формате TorchScript?

model = torch.jit.load('model_scripted.pt', map_location='cpu')
output = model(input_tensor)

TorchScript позволяет запускать модель без кода Python - удобно для продакшена.

Расширенные примеры загрузки модели PyTorch

Пример 1: Загрузка части весов с переименованием ключей (например, для переноса обучения с другой архитектуры)

Пример
import torch
import torch.nn as nn

# Определим простую модель
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

# Создадим другой класс с другим именем слоя
class AnotherModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fcA = nn.Linear(10, 20)
        self.fcB = nn.Linear(20, 5)

simple = SimpleModel()
another = AnotherModel()

# Сохраняем state_dict SimpleModel
torch.save(simple.state_dict(), 'simple.pth')

# Загружаем и вручную заменяем имена ключей
checkpoint = torch.load('simple.pth', map_location='cpu')
mapping = {'fc1.weight': 'fcA.weight', 'fc1.bias': 'fcA.bias',
           'fc2.weight': 'fcB.weight', 'fc2.bias': 'fcB.bias'}
adapted = {mapping.get(k, k): v for k, v in checkpoint.items()}
another.load_state_dict(adapted, strict=False)

print("Веса перенесены, пропущенные ключи проигнорированы.")
Веса перенесены, пропущенные ключи проигнорированы.

Пример 2: Загрузка модели с обработкой разных версий PyTorch (ключ _metadata)

Пример
checkpoint = torch.load('old_model.pth', map_location='cpu')
# Если сохранён словарь с метаданными (версия, производная информация)
if '_metadata' in checkpoint:
    del checkpoint['_metadata']  # удаляем для совместимости
model.load_state_dict(checkpoint, strict=False)

Пример 3: Загрузка модели с динамическим созданием класса (если модель была сохранена целиком)

Пример
import torch
import sys

# Сохранённая модель может требовать определения класса
class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)

# Перед torch.load класс должен быть доступен
model = torch.load('custom.pth', map_location='cpu')
print(type(model))  # 

Пример 4: Загрузка модели и немедленное тестирование на случайных данных

Пример
import torch
import torchvision.models as models

model = models.resnet50(pretrained=False)
checkpoint = torch.load('resnet50_weights.pth', map_location='cpu')
model.load_state_dict(checkpoint)
model.eval()

# Генерация случайного тензора и прогон
input_tensor = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(input_tensor)
print("Форма выхода:", output.shape)
Форма выхода: torch.Size([1, 1000])

Пример 5: Загрузка нескольких чекпоинтов с циклическим сохранением (например, best и last)

Пример
import os

checkpoints = ['best_model.pth', 'last_model.pth']
for path in checkpoints:
    if os.path.exists(path):
        state = torch.load(path, map_location='cpu')
        # Фильтр только state_dict
        if isinstance(state, dict) and 'model_state_dict' in state:
            print(f"{path}: эпоха {state['epoch']}")
        else:
            print(f"{path}: содержит только state_dict")
best_model.pth: эпоха 42
last_model.pth: эпоха 50

Пример 6: Загрузка модели с кастомным map_location, например, на 'cuda:1' с падением на CPU

Пример
try:
    model = torch.load('model.pth', map_location='cuda:1')
except RuntimeError:
    print("GPU cuda:1 недоступен, загружаем на CPU")
    model = torch.load('model.pth', map_location='cpu')

Пример 7: Использование lambda в map_location для детерминированного устройства

Пример
checkpoint = torch.load('model.pth',
                        map_location=lambda storage, loc: storage.cuda(0) if loc == 'cuda:0' else storage)

Загрузка модели PyTorch - comments

En
Python torch load (python)