您的位置:首頁 >聚焦 > 帶你入門擴(kuò)散模型:DDPM 2022-11-22 19:24:29 來源:程序員客棧 點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師”(資料圖)設(shè)為星標(biāo),干貨直達(dá)!?“What I cannot create, I do not understand.” -- Richard Feynman近段時(shí)間最火的方向無疑是基于文本用AI生成圖像,繼OpenAI在2021提出的文本轉(zhuǎn)圖像模型DALLE之后,越來越多的大公司卷入這個(gè)方向,如谷歌在今年相繼推出了Imagen和Parti。一些主流的文本轉(zhuǎn)圖像模型如DALL·E 2,stable-diffusion和Imagen采用了擴(kuò)散模型(Diffusion Model)作為圖像生成模型,這也引發(fā)了對擴(kuò)散模型的研究熱潮。相比GAN來說,擴(kuò)散模型訓(xùn)練更穩(wěn)定,而且能夠生成更多樣的樣本,OpenAI的論文Diffusion Models Beat GANs on Image Synthesis也證明了擴(kuò)散模型能夠超越GAN。簡單來說,擴(kuò)散模型包含兩個(gè)過程:前向擴(kuò)散過程和反向生成過程,前向擴(kuò)散過程是對一張圖像逐漸添加高斯噪音直至變成隨機(jī)噪音,而反向生成過程是去噪音過程,我們將從一個(gè)隨機(jī)噪音開始逐漸去噪音直至生成一張圖像,這也是我們要求解或者訓(xùn)練的部分。擴(kuò)散模型與其它主流生成模型的對比如下所示:目前所采用的擴(kuò)散模型大都是來自于2020年的工作DDPM: Denoising Diffusion Probabilistic Models,DDPM對之前的擴(kuò)散模型(具體見Deep Unsupervised Learning using Nonequilibrium Thermodynamics)進(jìn)行了簡化,并通過變分推斷(variational inference)來進(jìn)行建模,這主要是因?yàn)閿U(kuò)散模型也是一個(gè)隱變量模型(latent variable model),相比VAE這樣的隱變量模型,擴(kuò)散模型的隱變量是和原始數(shù)據(jù)是同維度的,而且推理過程(即擴(kuò)散過程)往往是固定的。這篇文章將基于DDPM詳細(xì)介紹擴(kuò)散模型的原理,并給出具體的代碼實(shí)現(xiàn)和分析。擴(kuò)散模型原理擴(kuò)散模型包括兩個(gè)過程:前向過程(forward process)和反向過程(reverse process),其中前向過程又稱為為擴(kuò)散過程(diffusion process),如下圖所示。無論是前向過程還是反向過程都是一個(gè)參數(shù)化的馬爾可夫鏈(Markov chain),其中反向過程可以用來生成數(shù)據(jù),這里我們將通過變分推斷來進(jìn)行建模和求解。擴(kuò)散過程擴(kuò)散過程是指的對數(shù)據(jù)逐漸增加高斯噪音直至數(shù)據(jù)變成隨機(jī)噪音的過程。對于原始數(shù)據(jù),總共包含步的擴(kuò)散過程的每一步都是對上一步得到的數(shù)據(jù)按如下方式增加高斯噪音:這里為每一步所采用的方差,它介于0~1之間。對于擴(kuò)散模型,我們往往稱不同step的方差設(shè)定為variance schedule或者noise schedule,通常情況下,越后面的step會采用更大的方差,即滿足。在一個(gè)設(shè)計(jì)好的variance schedule下,的如果擴(kuò)散步數(shù)足夠大,那么最終得到的就完全丟失了原始數(shù)據(jù)而變成了一個(gè)隨機(jī)噪音。擴(kuò)散過程的每一步都生成一個(gè)帶噪音的數(shù)據(jù),整個(gè)擴(kuò)散過程也就是一個(gè)馬爾卡夫鏈:另外要指出的是,擴(kuò)散過程往往是固定的,即采用一個(gè)預(yù)先定義好的variance schedule,比如DDPM就采用一個(gè)線性的variance schedule。擴(kuò)散過程的一個(gè)重要特性是我們可以直接基于原始數(shù)據(jù)來對任意步的進(jìn)行采樣:。這里定義和,通過重參數(shù)技巧(和VAE類似),那么有:上述推到過程利用了兩個(gè)方差不同的高斯分布和相加等于一個(gè)新的高斯分布。反重參數(shù)化后,我們得到:擴(kuò)散過程的這個(gè)特性很重要。首先,我們可以看到其實(shí)可以看成是原始數(shù)據(jù)和隨機(jī)噪音的線性組合,其中和為組合系數(shù),它們的平方和等于1,我們也可以稱兩者分別為signal_rate和noise_rate(見https://keras.io/examples/generative/ddim/#diffusion-schedule和Variational Diffusion Models)。更近一步地,我們可以基于而不是來定義noise schedule(見Improved Denoising Diffusion Probabilistic Models所設(shè)計(jì)的cosine schedule),因?yàn)檫@樣處理更直接,比如我們直接將設(shè)定為一個(gè)接近0的值,那么就可以保證最終得到的近似為一個(gè)隨機(jī)噪音。其次,后面的建模和分析過程將使用這個(gè)特性。反向過程擴(kuò)散過程是將數(shù)據(jù)噪音化,那么反向過程就是一個(gè)去噪的過程,如果我們知道反向過程的每一步的真實(shí)分布,那么從一個(gè)隨機(jī)噪音開始,逐漸去噪就能生成一個(gè)真實(shí)的樣本,所以反向過程也就是生成數(shù)據(jù)的過程。估計(jì)分布需要用到整個(gè)訓(xùn)練樣本,我們可以用神經(jīng)網(wǎng)絡(luò)來估計(jì)這些分布。這里,我們將反向過程也定義為一個(gè)馬爾卡夫鏈,只不過它是由一系列用神經(jīng)網(wǎng)絡(luò)參數(shù)化的高斯分布來組成:這里,而為參數(shù)化的高斯分布,它們的均值和方差由訓(xùn)練的網(wǎng)絡(luò)和給出。實(shí)際上,擴(kuò)散模型就是要得到這些訓(xùn)練好的網(wǎng)絡(luò),因?yàn)樗鼈儤?gòu)成了最終的生成模型。雖然分布是不可直接處理的,但是加上條件的后驗(yàn)分布卻是可處理的,這里有:下面我們來具體推導(dǎo)這個(gè)分布,首先根據(jù)貝葉斯公式,我們有:由于擴(kuò)散過程的馬爾卡夫鏈特性,我們知道分布(這里條件是多余的),而由前面得到的擴(kuò)散過程特性可知:所以,我們有:這里的是一個(gè)和無關(guān)的部分,所以省略。根據(jù)高斯分布的概率密度函數(shù)定義和上述結(jié)果(配平方),我們可以得到分布的均值和方差:可以看到方差是一個(gè)定量(擴(kuò)散過程參數(shù)固定),而均值是一個(gè)依賴和的函數(shù)。這個(gè)分布將會被用于推導(dǎo)擴(kuò)散模型的優(yōu)化目標(biāo)。優(yōu)化目標(biāo)上面介紹了擴(kuò)散模型的擴(kuò)散過程和反向過程,現(xiàn)在我們來從另外一個(gè)角度來看擴(kuò)散模型:如果我們把中間產(chǎn)生的變量看成隱變量的話,那么擴(kuò)散模型其實(shí)是包含個(gè)隱變量的隱變量模型(latent variable model),它可以看成是一個(gè)特殊的Hierarchical VAEs(見Understanding Diffusion Models: A Unified Perspective):相比VAE來說,擴(kuò)散模型的隱變量是和原始數(shù)據(jù)同維度的,而且encoder(即擴(kuò)散過程)是固定的。既然擴(kuò)散模型是隱變量模型,那么我們可以就可以基于變分推斷來得到variational lower bound(VLB,又稱ELBO)作為最大化優(yōu)化目標(biāo),這里有:這里最后一步是利用了Jensen"s inequality(不采用這個(gè)不等式的推導(dǎo)見博客What are Diffusion Models?),對于網(wǎng)絡(luò)訓(xùn)練來說,其訓(xùn)練目標(biāo)為VLB取負(fù):我們近一步對訓(xùn)練目標(biāo)進(jìn)行分解可得:可以看到最終的優(yōu)化目標(biāo)共包含項(xiàng),其中可以看成是原始數(shù)據(jù)重建,優(yōu)化的是負(fù)對數(shù)似然,可以用估計(jì)的來構(gòu)建一個(gè)離散化的decoder來計(jì)算(見DDPM論文3.3部分);而計(jì)算的是最后得到的噪音的分布和先驗(yàn)分布的KL散度,這個(gè)KL散度沒有訓(xùn)練參數(shù),近似為0,因?yàn)橄闰?yàn)而擴(kuò)散過程最后得到的隨機(jī)噪音也近似為;而則是計(jì)算的是估計(jì)分布和真實(shí)后驗(yàn)分布的KL散度,這里希望我們估計(jì)的去噪過程和依賴真實(shí)數(shù)據(jù)的去噪過程近似一致:之所以前面我們將定義為一個(gè)用網(wǎng)絡(luò)參數(shù)化的高斯分布,是因?yàn)橐ヅ涞暮篁?yàn)分布也是一個(gè)高斯分布。對于訓(xùn)練目標(biāo)和來說,都是希望得到訓(xùn)練好的網(wǎng)絡(luò)和(對于,)。DDPM對做了近一步簡化,采用固定的方差:,這里的可以設(shè)定為或者(這其實(shí)是兩個(gè)極端,分別是上限和下限,也可以采用可訓(xùn)練的方差,見論文Improved Denoising Diffusion Probabilistic Models和Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models)。這里假定,那么:對于兩個(gè)高斯分布的KL散度,其計(jì)算公式為(具體推導(dǎo)見生成模型之VAE):那么就有:那么優(yōu)化目標(biāo)即為:從上述公式來看,我們是希望網(wǎng)絡(luò)學(xué)習(xí)到的均值和后驗(yàn)分布的均值一致。不過DDPM發(fā)現(xiàn)預(yù)測均值并不是最好的選擇。根據(jù)前面得到的擴(kuò)散過程的特性,我們有:將這個(gè)公式帶入上述優(yōu)化目標(biāo),可以得到:近一步地,我們對也進(jìn)行重參數(shù)化,變成:這里的是一個(gè)基于神經(jīng)網(wǎng)絡(luò)的擬合函數(shù),這意味著我們由原來的預(yù)測均值而換成預(yù)測噪音。我們將上述等式帶入優(yōu)化目標(biāo),可以得到:DDPM近一步對上述目標(biāo)進(jìn)行了簡化,即去掉了權(quán)重系數(shù),變成了:這里的在[1, T]范圍內(nèi)取值(如前所述,其中取1時(shí)對應(yīng))。由于去掉了不同的權(quán)重系數(shù),所以這個(gè)簡化的目標(biāo)其實(shí)是VLB優(yōu)化目標(biāo)進(jìn)行了reweight。從DDPM的對比實(shí)驗(yàn)結(jié)果來看,預(yù)測噪音比預(yù)測均值效果要好,采用簡化版本的優(yōu)化目標(biāo)比VLB目標(biāo)效果要好:雖然擴(kuò)散模型背后的推導(dǎo)比較復(fù)雜,但是我們最終得到的優(yōu)化目標(biāo)非常簡單,就是讓網(wǎng)絡(luò)預(yù)測的噪音和真實(shí)的噪音一致。DDPM的訓(xùn)練過程也非常簡單,如下圖所示:隨機(jī)選擇一個(gè)訓(xùn)練樣本->從1-T隨機(jī)抽樣一個(gè)t->隨機(jī)產(chǎn)生噪音-計(jì)算當(dāng)前所產(chǎn)生的帶噪音數(shù)據(jù)(紅色框所示)->輸入網(wǎng)絡(luò)預(yù)測噪音->計(jì)算產(chǎn)生的噪音和預(yù)測的噪音的L2損失->計(jì)算梯度并更新網(wǎng)絡(luò)。一旦訓(xùn)練完成,其采樣過程也非常簡單,如上所示:我們從一個(gè)隨機(jī)噪音開始,并用訓(xùn)練好的網(wǎng)絡(luò)預(yù)測噪音,然后計(jì)算條件分布的均值(紅色框部分),然后用均值加標(biāo)準(zhǔn)差乘以一個(gè)隨機(jī)噪音,直至t=0完成新樣本的生成(最后一步不加噪音)。不過實(shí)際的代碼實(shí)現(xiàn)和上述過程略有區(qū)別(見https://github.com/hojonathanho/diffusion/issues/5:先基于預(yù)測的噪音生成,并進(jìn)行了clip處理(范圍[-1, 1],原始數(shù)據(jù)歸一化到這個(gè)范圍),然后再計(jì)算均值。我個(gè)人的理解這應(yīng)該算是一種約束,既然模型預(yù)測的是噪音,那么我們也希望用預(yù)測噪音重構(gòu)處理的原始數(shù)據(jù)也應(yīng)該滿足范圍要求。模型設(shè)計(jì)前面我們介紹了擴(kuò)散模型的原理以及優(yōu)化目標(biāo),那么擴(kuò)散模型的核心就在于訓(xùn)練噪音預(yù)測模型,由于噪音和原始數(shù)據(jù)是同維度的,所以我們可以選擇采用AutoEncoder架構(gòu)來作為噪音預(yù)測模型。DDPM所采用的模型是一個(gè)基于residual block和attention block的U-Net模型。如下所示:U-Net屬于encoder-decoder架構(gòu),其中encoder分成不同的stages,每個(gè)stage都包含下采樣模塊來降低特征的空間大?。℉和W),然后decoder和encoder相反,是將encoder壓縮的特征逐漸恢復(fù)。U-Net在decoder模塊中還引入了skip connection,即concat了encoder中間得到的同維度特征,這有利于網(wǎng)絡(luò)優(yōu)化。DDPM所采用的U-Net每個(gè)stage包含2個(gè)residual block,而且部分stage還加入了self-attention模塊增加網(wǎng)絡(luò)的全局建模能力。另外,擴(kuò)散模型其實(shí)需要的是個(gè)噪音預(yù)測模型,實(shí)際處理時(shí),我們可以增加一個(gè)time embedding(類似transformer中的position embedding)來將timestep編碼到網(wǎng)絡(luò)中,從而只需要訓(xùn)練一個(gè)共享的U-Net模型。具體地,DDPM在各個(gè)residual block都引入了time embedding,如上圖所示。代碼實(shí)現(xiàn)最后,我們基于PyTorch框架給出DDPM的具體實(shí)現(xiàn),這里主要參考了三套代碼實(shí)現(xiàn):GitHub - hojonathanho/diffusion: Denoising Diffusion Probabilistic Models(官方TensorFlow實(shí)現(xiàn))GitHub - openai/improved-diffusion: Release for Improved Denoising Diffusion Probabilistic Models (OpenAI基于PyTorch實(shí)現(xiàn)的DDPM+)GitHub - lucidrains/denoising-diffusion-pytorch: Implementation of Denoising Diffusion Probabilistic Model in Pytorch首先,是time embeding,這里是采用Attention Is All You Need中所設(shè)計(jì)的sinusoidal position embedding,只不過是用來編碼timestep:#usesinusoidalpositionembeddingtoencodetimestep(https://arxiv.org/abs/1706.03762)deftimestep_embedding(timesteps,dim,max_period=10000):"""Createsinusoidaltimestepembeddings.:paramtimesteps:a1-DTensorofNindices,oneperbatchelement.Thesemaybefractional.:paramdim:thedimensionoftheoutput.:parammax_period:controlstheminimumfrequencyoftheembeddings.:return:an[Nxdim]Tensorofpositionalembeddings."""half=dim//2freqs=torch.exp(-math.log(max_period)*torch.arange(start=0,end=half,dtype=torch.float32)/half).to(device=timesteps.device)args=timesteps[:,None].float()*freqs[None]embedding=torch.cat([torch.cos(args),torch.sin(args)],dim=-1)ifdim%2:embedding=torch.cat([embedding,torch.zeros_like(embedding[:,:1])],dim=-1)returnembedding由于只有residual block才引入time embedding,所以可以定義一些輔助模塊來自動(dòng)處理,如下所示:#defineTimestepEmbedSequentialtosupport`time_emb`asextrainputclassTimestepBlock(nn.Module):"""Anymodulewhereforward()takestimestepembeddingsasasecondargument."""@abstractmethoddefforward(self,x,emb):"""Applythemoduleto`x`given`emb`timestepembeddings."""classTimestepEmbedSequential(nn.Sequential,TimestepBlock):"""Asequentialmodulethatpassestimestepembeddingstothechildrenthatsupportitasanextrainput."""defforward(self,x,emb):forlayerinself:ifisinstance(layer,TimestepBlock):x=layer(x,emb)else:x=layer(x)returnx這里所采用的U-Net采用GroupNorm進(jìn)行歸一化,所以這里也簡單定義了一個(gè)norm layer以方便使用:#useGNfornormlayerdefnorm_layer(channels):returnnn.GroupNorm(32,channels)U-Net的核心模塊是residual block,它包含兩個(gè)卷積層以及shortcut,同時(shí)也要引入time embedding,這里額外定義了一個(gè)linear層來將time embedding變換為和特征維度一致,第一conv之后通過加上time embedding來編碼time:#ResidualblockclassResidualBlock(TimestepBlock):def__init__(self,in_channels,out_channels,time_channels,dropout):super().__init__()self.conv1=nn.Sequential(norm_layer(in_channels),nn.SiLU(),nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1))#pojectionfortimestepembeddingself.time_emb=nn.Sequential(nn.SiLU(),nn.Linear(time_channels,out_channels))self.conv2=nn.Sequential(norm_layer(out_channels),nn.SiLU(),nn.Dropout(p=dropout),nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1))ifin_channels!=out_channels:self.shortcut=nn.Conv2d(in_channels,out_channels,kernel_size=1)else:self.shortcut=nn.Identity()defforward(self,x,t):"""`x`hasshape`[batch_size,in_dim,height,width]``t`hasshape`[batch_size,time_dim]`"""h=self.conv1(x)#Addtimestepembeddingsh+=self.time_emb(t)[:,:,None,None]h=self.conv2(h)returnh+self.shortcut(x)這里還在部分residual block引入了attention,這里的attention和transformer的self-attention是一致的:#AttentionblockwithshortcutclassAttentionBlock(nn.Module):def__init__(self,channels,num_heads=1):super().__init__()self.num_heads=num_headsassertchannels%num_heads==0self.norm=norm_layer(channels)self.qkv=nn.Conv2d(channels,channels*3,kernel_size=1,bias=False)self.proj=nn.Conv2d(channels,channels,kernel_size=1)defforward(self,x):B,C,H,W=x.shapeqkv=self.qkv(self.norm(x))q,k,v=qkv.reshape(B*self.num_heads,-1,H*W).chunk(3,dim=1)scale=1./math.sqrt(math.sqrt(C//self.num_heads))attn=torch.einsum("bct,bcs->bts",q*scale,k*scale)attn=attn.softmax(dim=-1)h=torch.einsum("bts,bcs->bct",attn,v)h=h.reshape(B,-1,H,W)h=self.proj(h)returnh+x對于上采樣模塊和下采樣模塊,其分別可以采用插值和stride=2的conv或者pooling來實(shí)現(xiàn):#upsampleclassUpsample(nn.Module):def__init__(self,channels,use_conv):super().__init__()self.use_conv=use_convifuse_conv:self.conv=nn.Conv2d(channels,channels,kernel_size=3,padding=1)defforward(self,x):x=F.interpolate(x,scale_factor=2,mode="nearest")ifself.use_conv:x=self.conv(x)returnx#downsampleclassDownsample(nn.Module):def__init__(self,channels,use_conv):super().__init__()self.use_conv=use_convifuse_conv:self.op=nn.Conv2d(channels,channels,kernel_size=3,stride=2,padding=1)else:self.op=nn.AvgPool2d(stride=2)defforward(self,x):returnself.op(x)上面我們實(shí)現(xiàn)了U-Net的所有組件,就可以進(jìn)行組合來實(shí)現(xiàn)U-Net了:#ThefullUNetmodelwithattentionandtimestepembeddingclassUNetModel(nn.Module):def__init__(self,in_channels=3,model_channels=128,out_channels=3,num_res_blocks=2,attention_resolutions=(8,16),dropout=0,channel_mult=(1,2,2,2),conv_resample=True,num_heads=4):super().__init__()self.in_channels=in_channelsself.model_channels=model_channelsself.out_channels=out_channelsself.num_res_blocks=num_res_blocksself.attention_resolutions=attention_resolutionsself.dropout=dropoutself.channel_mult=channel_multself.conv_resample=conv_resampleself.num_heads=num_heads#timeembeddingtime_embed_dim=model_channels*4self.time_embed=nn.Sequential(nn.Linear(model_channels,time_embed_dim),nn.SiLU(),nn.Linear(time_embed_dim,time_embed_dim),)#downblocksself.down_blocks=nn.ModuleList([TimestepEmbedSequential(nn.Conv2d(in_channels,model_channels,kernel_size=3,padding=1))])down_block_chans=[model_channels]ch=model_channelsds=1forlevel,multinenumerate(channel_mult):for_inrange(num_res_blocks):layers=[ResidualBlock(ch,mult*model_channels,time_embed_dim,dropout)]ch=mult*model_channelsifdsinattention_resolutions:layers.append(AttentionBlock(ch,num_heads=num_heads))self.down_blocks.append(TimestepEmbedSequential(*layers))down_block_chans.append(ch)iflevel!=len(channel_mult)-1:#don"tusedownsampleforthelaststageself.down_blocks.append(TimestepEmbedSequential(Downsample(ch,conv_resample)))down_block_chans.append(ch)ds*=2#middleblockself.middle_block=TimestepEmbedSequential(ResidualBlock(ch,ch,time_embed_dim,dropout),AttentionBlock(ch,num_heads=num_heads),ResidualBlock(ch,ch,time_embed_dim,dropout))#upblocksself.up_blocks=nn.ModuleList([])forlevel,multinlist(enumerate(channel_mult))[::-1]:foriinrange(num_res_blocks+1):layers=[ResidualBlock(ch+down_block_chans.pop(),model_channels*mult,time_embed_dim,dropout)]ch=model_channels*multifdsinattention_resolutions:layers.append(AttentionBlock(ch,num_heads=num_heads))iflevelandi==num_res_blocks:layers.append(Upsample(ch,conv_resample))ds//=2self.up_blocks.append(TimestepEmbedSequential(*layers))self.out=nn.Sequential(norm_layer(ch),nn.SiLU(),nn.Conv2d(model_channels,out_channels,kernel_size=3,padding=1),)defforward(self,x,timesteps):"""Applythemodeltoaninputbatch.:paramx:an[NxCxHxW]Tensorofinputs.:paramtimesteps:a1-Dbatchoftimesteps.:return:an[NxCx...]Tensorofoutputs."""hs=[]#timestepembeddingemb=self.time_embed(timestep_embedding(timesteps,self.model_channels))#downstageh=xformoduleinself.down_blocks:h=module(h,emb)hs.append(h)#middlestageh=self.middle_block(h,emb)#upstageformoduleinself.up_blocks:cat_in=torch.cat([h,hs.pop()],dim=1)h=module(cat_in,emb)returnself.out(h)對于擴(kuò)散過程,其主要的參數(shù)就是timesteps和noise schedule,DDPM采用范圍為[0.0001, 0.02]的線性noise schedule,其默認(rèn)采用的總擴(kuò)散步數(shù)為1000。#betascheduledeflinear_beta_schedule(timesteps):scale=1000/timestepsbeta_start=scale*0.0001beta_end=scale*0.02returntorch.linspace(beta_start,beta_end,timesteps,dtype=torch.float64)我們定義個(gè)擴(kuò)散模型,它主要要提前根據(jù)設(shè)計(jì)的noise schedule來計(jì)算一些系數(shù),并實(shí)現(xiàn)一些擴(kuò)散過程和生成過程:classGaussianDiffusion:def__init__(self,timesteps=1000,beta_schedule="linear"):self.timesteps=timestepsifbeta_schedule=="linear":betas=linear_beta_schedule(timesteps)elifbeta_schedule=="cosine":betas=cosine_beta_schedule(timesteps)else:raiseValueError(f"unknownbetaschedule{beta_schedule}")self.betas=betasself.alphas=1.-self.betasself.alphas_cumprod=torch.cumprod(self.alphas,axis=0)self.alphas_cumprod_prev=F.pad(self.alphas_cumprod[:-1],(1,0),value=1.)#calculationsfordiffusionq(x_t|x_{t-1})andothersself.sqrt_alphas_cumprod=torch.sqrt(self.alphas_cumprod)self.sqrt_one_minus_alphas_cumprod=torch.sqrt(1.0-self.alphas_cumprod)self.log_one_minus_alphas_cumprod=torch.log(1.0-self.alphas_cumprod)self.sqrt_recip_alphas_cumprod=torch.sqrt(1.0/self.alphas_cumprod)self.sqrt_recipm1_alphas_cumprod=torch.sqrt(1.0/self.alphas_cumprod-1)#calculationsforposteriorq(x_{t-1}|x_t,x_0)self.posterior_variance=(self.betas*(1.0-self.alphas_cumprod_prev)/(1.0-self.alphas_cumprod))#below:logcalculationclippedbecausetheposteriorvarianceis0atthebeginning#ofthediffusionchainself.posterior_log_variance_clipped=torch.log(self.posterior_variance.clamp(min=1e-20))self.posterior_mean_coef1=(self.betas*torch.sqrt(self.alphas_cumprod_prev)/(1.0-self.alphas_cumprod))self.posterior_mean_coef2=((1.0-self.alphas_cumprod_prev)*torch.sqrt(self.alphas)/(1.0-self.alphas_cumprod))#gettheparamofgiventimesteptdef_extract(self,a,t,x_shape):batch_size=t.shape[0]out=a.to(t.device).gather(0,t).float()out=out.reshape(batch_size,*((1,)*(len(x_shape)-1)))returnout#forwarddiffusion(usingtheniceproperty):q(x_t|x_0)defq_sample(self,x_start,t,noise=None):ifnoiseisNone:noise=torch.randn_like(x_start)sqrt_alphas_cumprod_t=self._extract(self.sqrt_alphas_cumprod,t,x_start.shape)sqrt_one_minus_alphas_cumprod_t=self._extract(self.sqrt_one_minus_alphas_cumprod,t,x_start.shape)returnsqrt_alphas_cumprod_t*x_start+sqrt_one_minus_alphas_cumprod_t*noise#Getthemeanandvarianceofq(x_t|x_0).defq_mean_variance(self,x_start,t):mean=self._extract(self.sqrt_alphas_cumprod,t,x_start.shape)*x_startvariance=self._extract(1.0-self.alphas_cumprod,t,x_start.shape)log_variance=self._extract(self.log_one_minus_alphas_cumprod,t,x_start.shape)returnmean,variance,log_variance#Computethemeanandvarianceofthediffusionposterior:q(x_{t-1}|x_t,x_0)defq_posterior_mean_variance(self,x_start,x_t,t):posterior_mean=(self._extract(self.posterior_mean_coef1,t,x_t.shape)*x_start+self._extract(self.posterior_mean_coef2,t,x_t.shape)*x_t)posterior_variance=self._extract(self.posterior_variance,t,x_t.shape)posterior_log_variance_clipped=self._extract(self.posterior_log_variance_clipped,t,x_t.shape)returnposterior_mean,posterior_variance,posterior_log_variance_clipped#computex_0fromx_tandprednoise:thereverseof`q_sample`defpredict_start_from_noise(self,x_t,t,noise):return(self._extract(self.sqrt_recip_alphas_cumprod,t,x_t.shape)*x_t-self._extract(self.sqrt_recipm1_alphas_cumprod,t,x_t.shape)*noise)#computepredictedmeanandvarianceofp(x_{t-1}|x_t)defp_mean_variance(self,model,x_t,t,clip_denoised=True):#predictnoiseusingmodelpred_noise=model(x_t,t)#getthepredictedx_0:differentfromthealgorithm2inthepaperx_recon=self.predict_start_from_noise(x_t,t,pred_noise)ifclip_denoised:x_recon=torch.clamp(x_recon,min=-1.,max=1.)model_mean,posterior_variance,posterior_log_variance=\self.q_posterior_mean_variance(x_recon,x_t,t)returnmodel_mean,posterior_variance,posterior_log_variance#denoise_step:samplex_{t-1}fromx_tandpred_noise@torch.no_grad()defp_sample(self,model,x_t,t,clip_denoised=True):#predictmeanandvariancemodel_mean,_,model_log_variance=self.p_mean_variance(model,x_t,t,clip_denoised=clip_denoised)noise=torch.randn_like(x_t)#nonoisewhent==0nonzero_mask=((t!=0).float().view(-1,*([1]*(len(x_t.shape)-1))))#computex_{t-1}pred_img=model_mean+nonzero_mask*(0.5*model_log_variance).exp()*noisereturnpred_img#denoise:reversediffusion@torch.no_grad()defp_sample_loop(self,model,shape):batch_size=shape[0]device=next(model.parameters()).device#startfrompurenoise(foreachexampleinthebatch)img=torch.randn(shape,device=device)imgs=[]foriintqdm(reversed(range(0,timesteps)),desc="samplinglooptimestep",total=timesteps):img=self.p_sample(model,img,torch.full((batch_size,),i,device=device,dtype=torch.long))imgs.append(img.cpu().numpy())returnimgs#samplenewimages@torch.no_grad()defsample(self,model,image_size,batch_size=8,channels=3):returnself.p_sample_loop(model,shape=(batch_size,channels,image_size,image_size))#computetrainlossesdeftrain_losses(self,model,x_start,t):#generaterandomnoisenoise=torch.randn_like(x_start)#getx_tx_noisy=self.q_sample(x_start,t,noise=noise)predicted_noise=model(x_noisy,t)loss=F.mse_loss(noise,predicted_noise)returnloss其中幾個(gè)主要的函數(shù)總結(jié)如下:q_sample:實(shí)現(xiàn)的從到擴(kuò)散過程;q_posterior_mean_variance:實(shí)現(xiàn)的是后驗(yàn)分布的均值和方差的計(jì)算公式;predict_start_from_noise:q_sample的逆過程,根據(jù)預(yù)測的噪音來生成;p_mean_variance:根據(jù)預(yù)測的噪音來計(jì)算的均值和方差;p_sample:單個(gè)去噪step;p_sample_loop:整個(gè)去噪音過程,即生成過程。擴(kuò)散模型的訓(xùn)練過程非常簡單,如下所示:#trainepochs=10forepochinrange(epochs):forstep,(images,labels)inenumerate(train_loader):optimizer.zero_grad()batch_size=images.shape[0]images=images.to(device)#sampletuniformallyforeveryexampleinthebatcht=torch.randint(0,timesteps,(batch_size,),device=device).long()loss=gaussian_diffusion.train_losses(model,images,t)ifstep%200==0:print("Loss:",loss.item())loss.backward()optimizer.step()這里我們以mnist數(shù)據(jù)簡單實(shí)現(xiàn)了一個(gè)mnist-demo,下面是一些生成的樣本:對生成過程進(jìn)行采樣,如下所示展示了如何從一個(gè)隨機(jī)噪音生層一個(gè)手寫字體圖像:另外這里也提供了CIFAR10數(shù)據(jù)集的demo:ddpm_cifar10,不過只訓(xùn)練了200epochs,生成的圖像只是初見成效。小結(jié)相比VAE和GAN,擴(kuò)散模型的理論更復(fù)雜一些,不過其優(yōu)化目標(biāo)和具體實(shí)現(xiàn)卻并不復(fù)雜,這其實(shí)也讓人感嘆:一堆復(fù)雜的數(shù)據(jù)推導(dǎo),最終卻得到了一個(gè)簡單的結(jié)論。要深入理解擴(kuò)散模型,DDPM只是起點(diǎn),后面還有比較多的改進(jìn)工作,比如加速采樣的DDIM以及DDPM的改進(jìn)版本DDPM+和DDPM++。注:本人水平有限,如有謬誤,歡迎討論交流。參考Denoising Diffusion Probabilistic ModelsUnderstanding Diffusion Models: A Unified Perspectivehttps://spaces.ac.cn/archives/9119https://keras.io/examples/generative/ddim/What are Diffusion Models?https://cvpr2022-tutorial-diffusion-models.github.io/https://github.com/openai/improved-diffusionhttps://huggingface.co/blog/annotated-diffusionhttps://github.com/lucidrains/denoising-diffusion-pytorchhttps://github.com/hojonathanho/diffusion 關(guān)鍵詞: 擴(kuò)散過程 原始數(shù)據(jù) 高斯分布 相關(guān)閱讀 世界熱推薦:今晚7:00直播丨下一個(gè)突破... 今晚19:00,Cocos視頻號直播馬上點(diǎn)擊【預(yù)約】啦↓↓↓在運(yùn)營了三年... NFT周刊|Magic Eden宣布支持Polygon網(wǎng)... Block-986在NFT這樣的市場,每周都會有相當(dāng)多項(xiàng)目起起伏伏。在過去... 環(huán)球今亮點(diǎn)!頭條觀察 | DeFi的興衰與... 在比特幣得到機(jī)構(gòu)關(guān)注之后,許多財(cái)務(wù)專家預(yù)測世界將因?yàn)榧用茇泿诺?.. 重新審視合作,體育Crypto的可靠關(guān)系才能雙贏 Block-987即使在體育Crypto領(lǐng)域,人們的目光仍然集中在FTX上。隨著... 簡訊:前端單元測試,更進(jìn)一步 前端測試@2022如果從2014年Jest的第一個(gè)版本發(fā)布開始計(jì)算,前端開發(fā)... 焦點(diǎn)熱訊:劉強(qiáng)東這波操作秀 近日,劉強(qiáng)東發(fā)布京東全員信,信中提到:自2023年1月1日起,逐步為...