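Creating Stylized Face Images with JojoGAN
Introduction
Style transfer is a growing area of neural network research and a very useful feature that can be integrated into social media and AI applications. Several neural networks can map the style of an image and transfer it onto an input image, based on their training data. In this article, we look at JojoGAN and at the process of training it with just one reference style image, then generating any face image in that style.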
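JoJoGAN: One Shot Face Stylization
One-shot face stylization can be used in AI applications, social media filters, fun apps, and business use cases. With the growing popularity of AI-generated image and video filters, and their use on social media, in short videos, and in images, one-shot face stylization is a useful feature that app developers and social media companies can integrate into their end products.
So let's look at JojoGAN, a popular GAN architecture for one-shot face stylization.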
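JojoGAN architecture
JojoGAN is a style transfer procedure that transfers the style of a reference image onto face images. It uses GAN inversion to map the reference style image to a latent code, then builds approximately paired training data: by mixing that code with random latent codes, StyleGAN generates realistic face images, each of which is paired with the reference style image. This dataset is then used to fine-tune StyleGAN. Afterwards, a new input face is inverted with GAN inversion and the fine-tuned generator renders it in that particular style.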
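The pairing idea can be illustrated without StyleGAN itself. The NumPy sketch below is a simplification under stated assumptions: a W+ latent of 18 × 512 as for a 1024px StyleGAN2, one style latent blended with K random latents, and a hypothetical helper `mix_latents`. Layers outside `swap_layers` keep the style code, while the swapped layers are blended toward random codes; each mixed code would then be decoded into a realistic face and paired with the reference image.

```python
import numpy as np

N_LATENT, LATENT_DIM = 18, 512  # assumed W+ shape of a 1024px StyleGAN2

def mix_latents(style_w, random_w, swap_layers, alpha):
    """Blend one style latent with K random latents at the layers in swap_layers.

    style_w:  (N_LATENT, LATENT_DIM) GAN inversion of the reference style image
    random_w: (K, N_LATENT, LATENT_DIM) latents of K random faces
    alpha:    weight kept from style_w at the swapped layers (0 = fully random there)
    """
    mixed = np.repeat(style_w[None], random_w.shape[0], axis=0)  # copy style_w K times
    mixed[:, swap_layers] = alpha * style_w[swap_layers] + (1 - alpha) * random_w[:, swap_layers]
    return mixed

rng = np.random.default_rng(0)
style_w = rng.standard_normal((N_LATENT, LATENT_DIM))
random_w = rng.standard_normal((4, N_LATENT, LATENT_DIM))
swap = list(range(7, N_LATENT))  # later layers, as in the non-color-preserving setting
mixed = mix_latents(style_w, random_w, swap, alpha=0.0)
print(mixed.shape)  # (4, 18, 512)
```

In the notebook's training loop, the user-facing alpha is flipped (`alpha = 1-alpha`) before a blend of this form, so a style strength of 1.0 corresponds to `alpha=0.0` here.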
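Figure: JojoGAN architecture and workflow
JojoGAN can be trained with only one reference style in a very short time (under one minute) and produces high-quality stylized images.
Some JojoGAN examples
Some examples of stylized images generated by JojoGAN:
Stylized images can be generated for a wide variety of input styles, and the results can be modified.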
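JojoGAN code deep dive
Let's look at the JojoGAN implementation for generating stylized portraits. Several pretrained models are available and can be used directly, or the model can be fine-tuned on our own style images to change the style within a few minutes.
Setup and imports for JojoGAN
Clone the JojoGAN repository and import the necessary libraries. Create a few folders in the Google Colab storage to hold the inversion codes, style images, and models.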
!git clone https://github.com/mchong6/JoJoGAN.git
%cd JoJoGAN
!pip install tqdm gdown scikit-learn==0.22 scipy lpips dlib opencv-python wandb
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
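import torch
torch.backends.cudnn.benchmark = True
from torchvision import transforms, utils
from util import *
from PIL import Image
import math
import random
import os
import numpy
import matplotlib.pyplot as plt
from torch import nn, autograd, optim
from torch.nn import functional
from tqdm import tqdm
import wandb
from model import *
from e4e_projection import projection
from google.colab import files
from copy import deepcopy
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# folders for inversion codes, style images (raw and aligned) and model files
os.makedirs('inversion_codes', exist_ok=True)
os.makedirs('style_images', exist_ok=True)
os.makedirs('style_images_aligned', exist_ok=True)
os.makedirs('models', exist_ok=True)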
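Model files
Download the model files with PyDrive. A set of Google Drive IDs is provided for the pretrained models. These pretrained models can be used to generate stylized images right away, with varying degrees of accuracy. After that, you can train a model on your own styles.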
#Download models
#optionally enable downloads with pydrive in order to authenticate and avoid drive download limits.
download_with_pydrive = True
device = 'cuda' #['cuda', 'cpu']
!wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
!bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2
!mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat
%matplotlib inline
drive_ids = {
"stylegan2-ffhq-config-f.pt": "1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK",
"e4e_ffhq_encode.pt": "1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7",
"restyle_psp_ffhq_encode.pt": "1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd",
"arcane_caitlyn.pt": "1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc",
"arcane_caitlyn_preserve_color.pt": "1cUTyjU-q98P75a8THCaO545RTwpVV-aH",
"arcane_jinx_preserve_color.pt": "1jElwHxaYPod5Itdy18izJk49K1nl4ney",
"arcane_jinx.pt": "1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_",
"arcane_multi_preserve_color.pt": "1enJgrC08NpWpx2XGBmLt1laimjpGCyfl",
"arcane_multi.pt": "15V9s09sgaw-zhKp116VHigf5FowAy43f",
"sketch_multi.pt": "1GdaeHGBGjBAFsWipTL0y-ssUiAqk8AxD",
"disney.pt": "1zbE2upakFUAx8ximYnLofFwfT8MilqJA",
"disney_preserve_color.pt": "1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi",
"jojo.pt": "13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4",
"jojo_preserve_color.pt": "1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2",
"jojo_yasuho.pt": "1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_",
"jojo_yasuho_preserve_color.pt": "1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L",
"art.pt": "1a0QDEHwXQ6hE_FcYEyNMuv5r5UnRQLKT",
}
# adapted from StyleGAN-NADA
class Downloader(object):
    def __init__(self, use_pydrive):
        self.use_pydrive = use_pydrive
        if self.use_pydrive:
            self.authenticate()

    def authenticate(self):
        auth.authenticate_user()
        gauth = GoogleAuth()
        gauth.credentials = GoogleCredentials.get_application_default()
        self.drive = GoogleDrive(gauth)

    def download_file(self, file_name):
        file_dst = os.path.join('models', file_name)
        file_id = drive_ids[file_name]
        # skip files that were already downloaded
        if not os.path.exists(file_dst):
            print(f'Downloading {file_name}')
            if self.use_pydrive:
                downloaded = self.drive.CreateFile({'id': file_id})
                downloaded.FetchMetadata(fetch_all=True)
                downloaded.GetContentFile(file_dst)
            else:
                !gdown --id $file_id -O $file_dst

downloader = Downloader(download_with_pydrive)
downloader.download_file('stylegan2-ffhq-config-f.pt')
downloader.download_file('e4e_ffhq_encode.pt')
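The `Downloader` only fetches a checkpoint when it is not already in `models/`, so rerunning the cell is cheap. A minimal sketch of that cache check, where `download_if_missing` and `fake_fetch` are hypothetical stand-ins for the PyDrive and gdown calls:

```python
import os
import tempfile

def download_if_missing(file_name, fetch, models_dir="models"):
    """Call fetch(dst) only when models_dir/file_name is not already on disk.

    fetch is a hypothetical callback standing in for PyDrive or gdown.
    Returns True if a download was triggered, False if the cached file was reused.
    """
    dst = os.path.join(models_dir, file_name)
    if os.path.exists(dst):
        return False
    os.makedirs(models_dir, exist_ok=True)
    fetch(dst)
    return True

def fake_fetch(dst):
    # Stand-in downloader: writes a placeholder file instead of contacting Google Drive
    with open(dst, "wb") as f:
        f.write(b"weights")

with tempfile.TemporaryDirectory() as tmp:
    first = download_if_missing("e4e_ffhq_encode.pt", fake_fetch, tmp)
    second = download_if_missing("e4e_ffhq_encode.pt", fake_fetch, tmp)
    print(first, second)  # True False
```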
Loading the generators
Load the original generator and a copy of it that will be fine-tuned. Set up the transforms that resize and normalize images.
latent_dim = 512
# Load original generator
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
mean_latent = original_generator.mean_latent(10000)
# to be finetuned generator
generator = deepcopy(original_generator)
transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
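`Normalize` with a mean and standard deviation of 0.5 per channel maps the [0, 1] output of `ToTensor()` to [-1, 1], the range the StyleGAN2 generator is trained to produce. To view a generated tensor as an image it has to be mapped back with (x + 1) / 2; torchvision's `make_grid(..., normalize=True, value_range=(-1, 1))` performs an equivalent rescaling. A small NumPy sketch of the arithmetic (the helper names are illustrative):

```python
import numpy as np

def normalize(x, mean=0.5, std=0.5):
    # Per-channel arithmetic of transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    # Inverse mapping back to [0, 1]; equal to (y + 1) / 2 for mean = std = 0.5
    return y * std + mean

pixels = np.array([0.0, 0.25, 0.5, 1.0])  # ToTensor() scales 8-bit pixels into [0, 1]
scaled = normalize(pixels)                # 0 -> -1, 0.5 -> 0, 1 -> 1
restored = denormalize(scaled)
```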
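Input image
Set the location of the input image. Align and crop the face, then project it into StyleGAN's latent space (GAN inversion) so that it can be restyled.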
# Upload your image to the test_input directory and put its file name here
filename = 'face.jpeg' #@param {type:"string"}
filepath = f'test_input/{filename}'
name = strip_path_extension(filepath)+'.pt'
# aligns and crops face
aligned_face = align_face(filepath)
# my_w = restyle_projection(aligned_face, name, device, n_iters=1).unsqueeze(0)
my_w = projection(aligned_face, name, device).unsqueeze(0)
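`projection` returns a W+ code, one 512-dimensional style vector per style input of the generator, and `.unsqueeze(0)` adds a batch dimension. Assuming the rosinality StyleGAN2 layout that JoJoGAN's `model.py` follows, where the number of style inputs is 2 · log2(size) - 2, the expected shapes work out as below (`n_latent` here is an illustrative helper, not the library's attribute):

```python
import math

LATENT_DIM = 512  # latent_dim used in the notebook

def n_latent(resolution):
    # Style inputs of a rosinality-style StyleGAN2 generator: 2 * log2(resolution) - 2
    log_size = int(math.log2(resolution))
    return log_size * 2 - 2

w_plus_shape = (n_latent(1024), LATENT_DIM)  # assumed shape returned by projection()
my_w_shape = (1,) + w_plus_shape             # after .unsqueeze(0): a batch of one
print(my_w_shape)  # (1, 18, 512)
```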
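Pretrained styles
Choose the pretrained style type. A checkpoint that does not preserve color gives better results, because the stylized image then also takes on the colors of the reference style.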
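plt.rcParams['figure.dpi'] = 150
pretrained = 'sketch_multi' #['art', 'arcane_multi', 'sketch_multi', 'arcane_jinx', 'arcane_caitlyn', 'jojo_yasuho', 'jojo', 'disney']
# Preserve color tries to preserve the color of the original image by limiting the family of allowable transformations.
preserve_color = False
if preserve_color:
    ckpt = f'{pretrained}_preserve_color.pt'
else:
    ckpt = f'{pretrained}.pt'
# download the chosen checkpoint and load its weights into the generator used for stylization
downloader.download_file(ckpt)
ckpt = torch.load(os.path.join('models', ckpt), map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt["g"], strict=False)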
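In the `drive_ids` table, some styles (for example `art` and `sketch_multi`) have no `_preserve_color` checkpoint, so setting `preserve_color = True` for them makes `download_file` fail with a `KeyError`. A sketch of a hypothetical helper that falls back to the plain checkpoint:

```python
def checkpoint_name(pretrained, preserve_color, available):
    """Return the checkpoint file name for a pretrained style.

    available: collection of downloadable file names (e.g. the drive_ids dict).
    Falls back to the plain checkpoint when no _preserve_color variant exists.
    """
    if preserve_color:
        candidate = f"{pretrained}_preserve_color.pt"
        if candidate in available:
            return candidate
    return f"{pretrained}.pt"

# subset of the drive_ids keys listed above
available = {"sketch_multi.pt", "art.pt", "jojo.pt", "jojo_preserve_color.pt"}
print(checkpoint_name("jojo", True, available))          # jojo_preserve_color.pt
print(checkpoint_name("sketch_multi", True, available))  # sketch_multi.pt
```

Passing `drive_ids` itself as `available` works because `in` on a dict checks its keys.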
Generating results
Load the checkpoint and generator, set the seed, and generate the stylized images. The input image, a photo of Elon Musk, is stylized according to the chosen style type.
#Generate results
n_sample = 5 #@param {type:"number"}
seed = 3000 #@param {type:"number"}
torch.manual_seed(seed)
with torch.no_grad():
    generator.eval()
    z = torch.randn(n_sample, latent_dim, device=device)
    original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)
    sample = generator([z], truncation=0.7, truncation_latent=mean_latent)
    original_my_sample = original_generator(my_w, input_is_latent=True)
    my_sample = generator(my_w, input_is_latent=True)
# display reference images
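if pretrained == 'arcane_multi':
    style_path = 'style_images_aligned/arcane_jinx.png'
elif pretrained == 'sketch_multi':
    style_path = 'style_images_aligned/sketch.png'
else:
    style_path = f'style_images_aligned/{pretrained}.png'
style_image = transform(Image.open(style_path)).unsqueeze(0).to(device)
face = transform(aligned_face).unsqueeze(0).to(device)
# reference style, input face and stylized face side by side
my_output = torch.cat([style_image, face, my_sample], 0)
display_image(utils.make_grid(my_output, normalize=True, value_range=(-1, 1)), title='My sample')
# random faces from the original generator (top row) and the stylized generator (bottom row)
output = torch.cat([original_sample, sample], 0)
display_image(utils.make_grid(output, normalize=True, value_range=(-1, 1), nrow=n_sample), title='Random samples')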
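The random samples use the truncation trick: `mean_latent(10000)` estimates the average latent from 10,000 random z vectors passed through the mapping network, and `truncation=0.7` shrinks each sampled latent's distance from that average to 70%. As far as we know, the rosinality StyleGAN2 code JoJoGAN builds on computes `truncation_latent + truncation * (style - truncation_latent)`; the calls on `my_w` use the default truncation of 1, so the inverted face itself is not truncated. A NumPy sketch of the formula with stand-in latents:

```python
import numpy as np

def truncate(w, w_avg, psi=0.7):
    # Truncation trick: keep only a fraction psi of the latent's offset from w_avg
    return w_avg + psi * (w - w_avg)

w_avg = np.zeros(512)         # stand-in for mean_latent
w = np.full(512, 2.0)         # stand-in for a mapped random latent
w_trunc = truncate(w, w_avg)  # every entry becomes 1.4
```

Lower values of psi trade diversity for more typical, higher-quality faces; psi = 1 leaves the latent unchanged.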
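Generated results
The results shown were generated with the pretrained "jojo" style and look quite accurate.
Now let's look at training the GAN on a style of our own.
Training with your own style images
Choose a few face images, or even create some of your own, put them in the style_images folder, and list their file names in names. The code crops and aligns each face and performs GAN inversion on it.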
names = ['1.jpg', '2.jpg', '3.jpg']
targets = []
latents = []
for name in names:
    style_path = os.path.join('style_images', name)
    assert os.path.exists(style_path), f"{style_path} does not exist!"
    name = strip_path_extension(name)
    # crop and align the face
    style_aligned_path = os.path.join('style_images_aligned', f'{name}.png')
    if not os.path.exists(style_aligned_path):
        style_aligned = align_face(style_path)
        style_aligned.save(style_aligned_path)
    else:
        style_aligned = Image.open(style_aligned_path).convert('RGB')
    # GAN invert
    style_code_path = os.path.join('inversion_codes', f'{name}.pt')
    if not os.path.exists(style_code_path):
        latent = projection(style_aligned, style_code_path, device)
    else:
        latent = torch.load(style_code_path)['latent']
    # the aligned style image is the training target, its latent code is the input
    targets.append(transform(style_aligned).to(device))
    latents.append(latent.to(device))
targets = torch.stack(targets, 0)
latents = torch.stack(latents, 0)
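Each style image produces two cached artifacts, an aligned crop and an inversion code, so rerunning the loop skips alignment and inversion for images it has already processed. A sketch of how the file names are derived; `style_paths` is a hypothetical helper, and it assumes `strip_path_extension` behaves like `os.path.splitext(...)[0]`:

```python
import os

def style_paths(file_name):
    """Hypothetical helper: cache locations the training loop uses for one style image."""
    stem = os.path.splitext(file_name)[0]  # assumed to match strip_path_extension(name)
    return {
        "source": os.path.join("style_images", file_name),
        "aligned": os.path.join("style_images_aligned", f"{stem}.png"),
        "inversion": os.path.join("inversion_codes", f"{stem}.pt"),
    }

paths = style_paths("1.jpg")
```

Deleting a file from `inversion_codes` forces that style image to be inverted again on the next run.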
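Fine-tuning StyleGAN
Fine-tune StyleGAN by setting alpha (the style strength), color preservation, and the number of iterations. Load the discriminator used for the perceptual loss and reset the generator.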
#Finetune StyleGAN
#alpha controls the strength of the style
alpha = 1.0 # min:0, max:1, step:0.1
alpha = 1-alpha
#preserve color of original image by limiting family of allowable transformations
preserve_color = False
#Number of finetuning steps.
num_iter = 300
#Log training on wandb and interval for image logging
use_wandb = False
log_interval = 50
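if use_wandb:
    wandb.init(project="JoJoGAN")
    config = wandb.config
    config.num_iter = num_iter
    config.preserve_color = preserve_color
    # grid of the style reference images, logged once at step 0
    target_im = utils.make_grid(targets, normalize=True, value_range=(-1, 1))
    wandb.log(
        {"Style reference": [wandb.Image(transforms.ToPILImage()(target_im))]},
        step=0)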
# load discriminator for perceptual loss
discriminator = Discriminator(1024, 2).eval().to(device)
ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
discriminator.load_state_dict(ckpt["d"], strict=False)
# reset generator
del generator
generator = deepcopy(original_generator)
g_optim = optim.Adam(generator.parameters(), lr=2e-3, betas=(0, 0.99))
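Then train the generator: at each step, the style latents are mixed with random latents at the layers listed in id_swap, the generator decodes the mixed latents into images, and the optimizer minimizes the L1 distance between the discriminator features of these images and those of the style references.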
# layers whose latents are mixed with random latents to build the training pairs
if preserve_color:
    id_swap = [9, 11, 15, 16, 17]
else:
    id_swap = list(range(7, generator.n_latent))
for idx in tqdm(range(num_iter)):
    mean_w = generator.get_latent(torch.randn([latents.size(0), latent_dim]).to(device)).unsqueeze(1).repeat(1, generator.n_latent, 1)
    in_latent = latents.clone()
    in_latent[:, id_swap] = alpha*latents[:, id_swap] + (1-alpha)*mean_w[:, id_swap]
    img = generator(in_latent, input_is_latent=True)
    # features of the style references are fixed targets, so no gradients are needed
    with torch.no_grad():
        real_feat = discriminator(targets)
    fake_feat = discriminator(img)
    # perceptual loss: mean L1 distance between discriminator features
    loss = sum([functional.l1_loss(a, b) for a, b in zip(fake_feat, real_feat)])/len(fake_feat)
    if use_wandb:
        wandb.log({"loss": loss.item()}, step=idx)
        if idx % log_interval == 0:
            generator.eval()
            with torch.no_grad():
                my_sample = generator(my_w, input_is_latent=True)
            generator.train()
            my_sample = transforms.ToPILImage()(utils.make_grid(my_sample, normalize=True, value_range=(-1, 1)))
            wandb.log(
                {"Current stylization": [wandb.Image(my_sample)]},
                step=idx)
    g_optim.zero_grad()
    loss.backward()
    g_optim.step()
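The loss above compares the discriminator's intermediate feature maps rather than pixels. Its reduction can be checked on toy arrays: each layer contributes the mean absolute difference (what `functional.l1_loss` computes with its default `reduction='mean'`), and the per-layer values are averaged. A NumPy sketch with a hypothetical `feature_l1_loss` helper:

```python
import numpy as np

def feature_l1_loss(fake_feats, real_feats):
    """Average over layers of the mean absolute difference between feature maps."""
    per_layer = [np.mean(np.abs(a - b)) for a, b in zip(fake_feats, real_feats)]
    return sum(per_layer) / len(per_layer)

# two toy "layers" with different shapes, as discriminator features would have
fake = [np.zeros((2, 3)), np.ones((2, 4))]
real = [np.ones((2, 3)), np.ones((2, 4))]
loss = feature_l1_loss(fake, real)  # layer errors 1.0 and 0.0, so the loss is 0.5
```

Averaging per layer means a small, deep feature map weighs as much as a large, shallow one, regardless of how many elements each contains.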
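Generating results with JojoGAN
Now generate the results. Below, results are generated for the input image and for random samples from both the original and the fine-tuned generator, for comparison.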
#Generate results
n_sample = 5
seed = 3000
torch.manual_seed(seed)
with torch.no_grad():
    generator.eval()
    z = torch.randn(n_sample, latent_dim, device=device)
    original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)
    sample = generator([z], truncation=0.7, truncation_latent=mean_latent)
    original_my_sample = original_generator(my_w, input_is_latent=True)
    my_sample = generator(my_w, input_is_latent=True)
# display reference images
style_images = []
for name in names:
    style_path = f'style_images_aligned/{strip_path_extension(name)}.png'
    style_image = transform(Image.open(style_path))
    style_images.append(style_image)
face = transform(aligned_face).to(device).unsqueeze(0)
style_images = torch.stack(style_images, 0).to(device)
display_image(utils.make_grid(style_images, normalize=True, value_range=(-1, 1)), title='References')
my_output = torch.cat([face, my_sample], 0)
display_image(utils.make_grid(my_output, normalize=True, value_range=(-1, 1)), title='My sample')
output = torch.cat([original_sample, sample], 0)
display_image(utils.make_grid(output, normalize=True, value_range=(-1, 1), nrow=n_sample), title='Random samples')
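Generated results
You can now use JojoGAN to generate images in your own style. The results are impressive, but they can be improved further by tuning the training procedure and capturing more features from the training images.
Conclusion
JojoGAN can map and transfer user-defined styles accurately, quickly, and efficiently. The key takeaways are:
· JojoGAN can be trained with just one style and then easily map it onto any face to create a stylized image
· JojoGAN is very fast and efficient and can be trained in under a minute
· The results are very accurate and resemble realistic portraits
· JojoGAN can easily be fine-tuned and modified, which makes it suitable for AI applications
Therefore, regardless of the style's type, shape, and color, JojoGAN is an ideal neural network for style transfer and can be a very useful feature in all kinds of social media and AI applications.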