SciPy中的截断多变量正态分布?

lf3rwulv  于 8个月前  发布在  其他
关注(0)|答案(5)|浏览(77)

我试图自动化一个过程,在某些时候需要从截断的多元正态分布中抽取样本。也就是说,它是一个正常的多元正态分布(即。高斯),但变量被约束到一个长方体。我给出的输入是完整的多元正态分布的均值和协方差,但我需要在我的盒子里有样本。
到目前为止,我只是在框外拒绝样本,并在必要时重新排序,但我开始发现我的过程有时会给我(a)大的协方差和(b)接近边缘的均值。这两件事合在一起影响了我的系统的速度。
所以我想做的是首先正确地对分布进行采样。谷歌搜索只导致了这个讨论或scipy.stats中的truncnorm发行版。前者是不确定的,后者似乎是一个变量。是否存在任何原生的多元截断正态分布?这会比拒绝样本更好吗?或者我应该做些更聪明的事情?
我将开始研究我自己的解决方案,这将是将未截断的高斯旋转到它的主轴(使用SVD分解或其他方法),使用截断高斯的乘积对分布进行采样,然后将该样本旋转回来,并在必要时拒绝/重新采样。如果截尾采样更有效,我认为这应该更快地采样所需的分布。

balp4ylt

balp4ylt1#

因此,根据the Wikipedia article,对多元截断正态分布(MTND)进行采样更加困难。我最终采取了一种相对简单的方法,并使用MCMC采样器放松了对MTND的初步猜测,如下所示。
我使用emcee来做MCMC工作。我发现这个软件包非常容易使用。它只需要一个函数来返回所需分布的对数概率。所以我定义了这个函数

from numpy.linalg import inv

def lnprob_trunc_norm(x, mean, bounds, C):
    if np.any(x < bounds[:,0]) or np.any(x > bounds[:,1]):
        return -np.inf
    else:
        return -0.5*(x-mean).dot(inv(C)).dot(x-mean)

这里,C是多元正态分布的协方差矩阵。然后,您可以运行类似于

S = emcee.EnsembleSampler(Nwalkers, Ndim, lnprob_trunc_norm, args = (mean, bounds, C))

pos, prob, state = S.run_mcmc(pos, Nsteps)

对于给定的meanboundsC。你需要对步行者的位置pos进行初始猜测,这可能是平均值周围的一个球,

pos = emcee.utils.sample_ball(mean, np.sqrt(np.diag(C)), size=Nwalkers)

或者从未截断的多元正态中采样,

pos = numpy.random.multivariate_normal(mean, C, size=Nwalkers)

我个人首先做了几千步的样本丢弃,因为它很快,然后迫使剩余的离群值回到边界内,然后运行MCMC采样。
收敛的步骤由您决定。
还请注意,emcee通过将参数threads=Nthreads添加到EnsembleSampler初始化来轻松支持基本并行化。所以你可以速战速决。

fkaflof6

fkaflof62#

我重新实现了一个算法,它不依赖于MCMC,而是从截断的多元正态分布中创建独立同分布(iid)样本。拥有iid样本可能非常有用!我过去也使用过Warrick在回答中描述的emcee,但是为了收敛,所需的样本数量在更高的维度上爆炸,这对我的用例来说是不切实际的。
该算法由Botev (2016)引入,并使用基于极小极大指数倾斜的接受-拒绝算法。它是originally implemented in MATLAB,但与在Python中使用MATLAB引擎运行它相比,为Python重新实现它显着提高了性能。它在更高的维度上也工作得很好,速度很快。
该代码可从以下网址获得:https://github.com/brunzema/truncated-mvn-sampler

示例:

d = 10  # dimensions

# random mu and cov
mu = np.random.rand(d)
cov = 0.5 - np.random.rand(d ** 2).reshape((d, d))
cov = np.triu(cov)
cov += cov.T - np.diag(cov.diagonal())
cov = np.dot(cov, cov)

# constraints
lb = np.zeros_like(mu) - 1
ub = np.ones_like(mu) * np.inf

# create truncated normal and sample from it
n_samples = 100000
tmvn = TruncatedMVN(mu, cov, lb, ub)
samples = tmvn.sample(n_samples)

绘制第一个尺寸将导致:

参考:

Botev,Z.I.,(2016),线性约束下的正态律:模拟和估计通过极小极大倾斜,杂志的皇家统计学会系列B,79,第1期,第。125-148

wpx232ag

wpx232ag3#

模拟截断的多元正态可能很棘手,通常涉及MCMC的一些条件采样。
我的简短回答是,你可以使用我的代码(https://github.com/ralphma1203/trun_mvnt)!!!它实现了从x1c 0d1x开始的Gibbs采样器算法,该算法可以处理

形式的一般线性约束,即使您有非满秩D和比维数更多的约束。

import numpy as np
from trun_mvnt import rtmvn, rtmvt

########## Traditional problem, probably what you need... ##########
##### lower < X < upper #####
# So D = identity matrix

D = np.diag(np.ones(4))
lower = np.array([-1,-2,-3,-4])
upper = -lower
Mean = np.zeros(4)
Sigma = np.diag([1,2,3,4])

n = 10 # want 500 final sample
burn = 100 # burn-in first 100 iterates
thin = 1 # thinning for Gibbs

random_sample = rtmvn(n, Mean, Sigma, D, lower, upper, burn, thin) 
# Numpy array n-by-p as result!
random_sample

########## Non-full rank problem (more constraints than dimension) ##########
Mean = np.array([0,0])
Sigma = np.array([1, 0.5, 0.5, 1]).reshape((2,2)) # bivariate normal

D = np.array([1,0,0,1,1,-1]).reshape((3,2)) # non-full rank problem
lower = np.array([-2,-1,-2])
upper = np.array([2,3,5])

n = 500 # want 500 final sample
burn = 100 # burn-in first 100 iterates
thin = 1 # thinning for Gibbs

random_sample = rtmvn(n, Mean, Sigma, D, lower, upper, burn, thin) # Numpy array n-by-p as result!
t30tvxxf

t30tvxxf4#

我编写了一个脚本来测量到目前为止提供的解决方案的运行时间。要绘制50维分布的100个样本,运行时间为

  1. Botev 2016 implementation 0.029579秒
  2. Li and Ghosh 2015 implementation 2.597150秒
    1.使用emcee 217.914969 s
    为了比较,使用Cholesky分解的没有任何截断的采样花费0.000201秒。
    对于这个特定的场景,Botev 2016 implementation是从截断的多元正态分布中采样的最快方法。emcee方法要慢得多,但也许可以调整它以获得更好的性能。
    输出是在6核x86机器上用以下代码生成的。
import emcee
from minimax_tilting_sampler import TruncatedMVN
from trun_mvnt import rtmvn
import numpy as np
from scipy import linalg
from multiprocessing import Pool

import time

def log_prob_mvtnd(x, mean, lb, ub, vcm):
    if np.any(np.less(x,lb)) or np.any(np.greater(x,ub)):
        return -np.inf
    else:
        diff = x - mean
        return -0.5 * np.dot(diff, np.linalg.solve(vcm, diff))

def tmvnd_emcee(vcm,lb,ub,n,mean=None):
    
    ndim = vcm.shape[0]
    
    if not mean:
        mean = np.zeros(ndim)
    
    with Pool() as pool:
        sampler = emcee.EnsembleSampler(nwalkers=n,
                                        ndim=ndim,
                                        log_prob_fn=log_prob_mvtnd,
                                        args = (mean, lb,ub, vcm),
                                        pool=pool
                                        )
    
        p0 = np.random.multivariate_normal(mean, vcm, size=n)
        # p0 = emcee.utils.sample_ball(mean, np.sqrt(np.diag(vcm)), size=n)
        
        state = sampler.run_mcmc(p0, 10)
        sampler.reset()
        sampler.run_mcmc(state, 100);

    return sampler.get_last_sample().coords

def mvnd_chol(vcm,n):
    gauss_samples = np.random.randn(vcm.shape[0],n)
    R = linalg.cholesky(vcm, lower=True)
    return np.dot(R,gauss_samples)
    
def tmvnd_botev16(vcm,lb,ub,n):
    tmvn = TruncatedMVN(np.zeros(vcm.shape[0]), vcm, lb, ub)
    return tmvn.sample(n)

def tmvnd_Li_Ghosh_15(vcm,lb,ub,n):
    burn = 100 # burn-in first 100 iterates
    thin = 1 # thinning for Gibbs

    D =np.eye((vcm.shape[0]))
    return rtmvn(n, np.zeros(vcm.shape[0]), vcm, D, lb, ub, burn, thin).T 

if __name__ == '__main__':
    
    ndim=50
    nsamples = 100
    sigma = np.exp(-1*np.linspace(0, 10, ndim))
    vcm = linalg.toeplitz(sigma)
    
    trunc_sigma = 2
    ub = np.sqrt(np.diag(vcm))*trunc_sigma
    lb = -ub
    
    
    samplers = { 'not truncated' : lambda vcm,lb,ub,n: mvnd_chol(vcm,n),
                 'emcee' : lambda vcm,lb,ub,n: tmvnd_emcee(vcm,lb,ub,n),
                 'Botev 2016': lambda vcm,lb,ub,n: tmvnd_botev16(vcm,lb,ub,n),
                 'Li and Ghosh 2015': lambda vcm,lb,ub,n: tmvnd_Li_Ghosh_15(vcm,lb,ub,n)}
    
    samples = []
    for name, sampler in samplers.items():
        start = time.time()
        samples.append(sampler(vcm,lb,ub,nsamples))
        end = time.time()
        print("%s %f s" %(name, end - start))
3pmvbmvn

3pmvbmvn5#

我想有点晚了,但为了记录在案,你可以用汉密尔顿蒙特卡罗。Matlab中有一个模块叫做HMC exact。翻译成Py应该不难。

相关问题