Style Transfer

Transferring image style with a convolutional neural network


Original image

In [1]:
from torchvision.models import vgg19
from torchvision import transforms
from PIL import Image
from matplotlib import pyplot as plt
import torch
import torch.optim as optim
import numpy as np
In [2]:
model = vgg19(pretrained=True)   # ImageNet-pretrained VGG19
In [3]:
for p in model.parameters():
    p.requires_grad_(False)   # freeze all VGG weights; only the image will be optimized
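The VGG weights stay frozen throughout; only the generated image is optimized. If a GPU is available, moving the model (and, later, the image tensors) to it speeds up optimization considerably. A minimal optional sketch, not part of the original run, which assumes CPU everywhere (if you adopt it, also move content, style and target to device, and call .cpu() before .numpy() in im_convert below):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).eval()   # eval mode; weights are already frozen above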
In [4]:
def load_img(path, max_size=400, shape=None):
    '''Load an image, cap its longer side at max_size, and normalize it for VGG.'''
    img = Image.open(path).convert('RGB')
    size = min(max(img.size), max_size)
    if shape is not None:
        size = shape
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),   # ImageNet mean
                             (0.229, 0.224, 0.225))   # ImageNet std
    ])
    img = transform(img)[:3, :, :].unsqueeze(0)   # drop any alpha channel, add batch dim
    return img
In [5]:
content = load_img('QQ图片20220207183552.jpg')                         # content image
style = load_img('201609281450067826.jpg', shape=content.shape[-2:])   # style image, resized to match
In [6]:
model   # inspect the architecture; the layer indices below refer to this printout
Out[6]:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace=True)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace=True)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace=True)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace=True)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace=True)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
In [7]:
def get_features(img):
    '''Run img through VGG19's feature stack, recording selected activations.'''
    layers = {'0': 'conv1_1',
              '5': 'conv2_1',
              '10': 'conv3_1',
              '19': 'conv4_1',
              '21': 'conv4_2',   # content layer
              '28': 'conv5_1'}
    features = {}
    x = img
    for name, layer in model.features._modules.items():
        x = layer(x)
        if name in layers:
            features[name] = x   # keyed by layer index, as used below
    return features
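As an optional sanity check (not in the original notebook), the returned dictionary should hold one activation per selected layer, with channel counts matching the VGG19 printout above:

x = torch.randn(1, 3, 224, 224)   # dummy batch
for k, v in get_features(x).items():
    print(k, tuple(v.shape))
# expected: '0' (1, 64, 224, 224) ... '21' (1, 512, 28, 28), '28' (1, 512, 14, 14)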
In [8]:
content_features=get_features(content)
style_features=get_features(style)
In [9]:
def gram_matrix(tensor):
    '''Channel-by-channel correlation matrix of a (1, d, h, w) feature map.'''
    _, d, h, w = tensor.size()
    tensor = tensor.view(d, h * w)        # flatten each channel to a row
    gram = torch.mm(tensor, tensor.t())   # (d, d) channel correlations
    return gram
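The Gram matrix G = F·Fᵀ records how strongly every pair of channels co-activates while discarding spatial layout, which is what makes it a texture (style) descriptor. A small illustrative check, not part of the original run:

f = torch.randn(1, 64, 32, 32)    # fake conv feature map
g = gram_matrix(f)
print(g.shape)                    # torch.Size([64, 64]): one entry per channel pair
print(torch.allclose(g, g.t()))   # True; Gram matrices are symmetric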
In [10]:
# precompute the style targets once; they never change during optimization
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
In [11]:
target = content.clone().requires_grad_(True)   # start from the content image; this tensor is what we optimize
In [12]:
def im_convert(tensor):
    '''Undo the ImageNet normalization and return an HWC array for plotting.'''
    img = tensor.clone().detach()   # assumes a CPU tensor
    img = img.numpy().squeeze()
    img = img.transpose(1, 2, 0)    # CHW -> HWC
    img = img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    img = img.clip(0, 1)
    return img
In [13]:
'''Per-layer style weights: earlier layers (finer textures) weigh more.'''
style_weights = {
    '0': 1,
    '5': 0.8,
    '10': 0.5,
    '19': 0.3,
    '28': 0.1,
}
'''Relative weights of the two loss terms.'''
content_weight = 1
style_weight = 1e6
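Together these define the objective minimized in the next cell, with α = content_weight, β = style_weight, w_l the per-layer style weights, F the conv4_2 feature maps, and G_l the Gram matrices of layer l with shape (1, d_l, h_l, w_l):

$$\mathcal{L}_{total} = \alpha \cdot \mathrm{mean}\big((F_{target} - F_{content})^2\big) + \beta \sum_{l} \frac{w_l \cdot \mathrm{mean}\big((G_l^{target} - G_l^{style})^2\big)}{d_l\, h_l\, w_l}$$

Only the ratio β/α really matters; the large β = 1e6 offsets the division by the feature-map size d·h·w inside the sum.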
In [14]:
show_every = 100
optimizer = optim.Adam([target], lr=0.003)
steps = 5000

for ii in range(steps):
    target_features = get_features(target)
    # content loss: MSE between conv4_2 activations of target and content
    content_loss = torch.mean((content_features['21'] - target_features['21'])**2)
    style_loss = 0
    '''Accumulate the Gram-matrix loss from every style layer'''
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        style_loss += layer_style_loss / (d * h * w)   # normalize by feature-map size
    total_loss = content_weight * content_loss + style_weight * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if ii % show_every == 0:
        print('Total Loss:', total_loss.item())
        plt.imshow(im_convert(target))
        plt.show()
In [15]:
plt.imshow(im_convert(target))
plt.show()
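To keep the result, the de-normalized array from im_convert can be written straight to disk; a minimal sketch (the output filename is arbitrary, not from the original notebook):

plt.imsave('style_transfer_result.png', im_convert(target))   # values are already clipped to [0, 1]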