0%

论文阅读-ISTA和LISTA

学习一下LISTA的思想

Learning Fast Approximations of Sparse Coding

主要阅读这篇论文,代码参考这篇文章

学习这篇文章的目的是为掌握一下在普通的算法中加入网络模型后的训练方法,而不是局限于ISTA算法本身的细节。

ISTA

ISTA目的是求解这样一个问题,对于一个n维向量,希望找到一个最佳的稀疏编码,且编码维度为m小于n,使得可以由线性表示,即

由此,可以构建带有L1正则化项的损失函数(L1正则化项能让输出趋向于稀疏),而ISTA算法本身也可以求解lasso回归问题,可以有效避免求解大矩阵的逆

所以目的就是最小化此损失函数即可,而ISTA的思想就是通过迭代的方式来求解,每次迭代的更新公式为,其中L是一个常数,要比的最大特征值大

只要把里面展开再合并同类项即可有

其中可以看成是ReLu函数左半部分由右半部分中心对称得到的。而这个阈值

通过这样一次次迭代,那么就能进行稀疏编码的求解。这个算法的大致框图和伪代码如下

**需要注意,ISTA算法仅仅求解了最佳的稀疏编码,是不能去求解

LISTA

LISTA则是增加了深度学习的ISTA算法,可以看成是很多个全连接块拼接在一起得到的。

而LISTA算法则可以看成是学习出所有的,然后通过这些参数来进行稀疏编码的求解。

由于是学习的算法,所以自然涉及到前向推断和反向传播的过程。

前向传播

前向传播与ISTA算法基本一致,基本上就是和之前一样的步骤,来求出优化解。

值得注意的是,这里面看起来也是在不断进行迭代,但是这个迭代的目的是求解出编码;模型的下降则是在反向传播中进行的。

可以这样理解,反向传播中的梯度下降是对模型进行调节,使得模型大体上能够符合训练集的分布;而前向传播中的下降则是对编码进行调节,求解的是参数对于稀疏编码的最优损失函数的解。

换言之,这个algorithm 3就是在模型的总体分布已经被反向传播固定的情况下,通过迭代的方法来求解出对应的输入下的最优编码。

还需要注意一点的就是,前向传播里计算的Z(t)、C(t)、B是同样会在反向传播中使用到

反向传播

LISTA的反向传播步可能是一个难点,我们先来看伪代码

这里面的Z(t)、C(t)、B是之前前向传播计算出来的。而这个算法中的并不是需要用到的值,而是类似C语言那用用引用实现的多返回的写法。

首先使第一步中是为了求解出当前模型前向推断值和标签的误差,就宛若神经网络中的反向传播步骤一样。

但是由于在计算编码的时候需要进行多次的迭代操作,所以这里也需要从后向前依次计算,每次迭代都可以看成是经过了一次全连接的自己算,但是更新权重的时候所有需要更新的梯度最终都会被作用到一个神经元上。

当然,这里其实也只是理论上的分析可能会困难一点,实际上由于pytorch的自动求导的机制,实际写起来并没有很困难。

下面来看一下这几个偏导在理论上是怎么求的。

已知前向传播的两个等式,分别记为等式1和等式2


对于而言,只需要对等式2进行两边对求偏导并变形即可。

而对于其他的梯度更新项而言,都是对于等式1中对应的位置求解偏导得到的。

此外,由于前向传播是从第0项(初始值)到第T项的前向计算,而反向传播的循环中仅仅求了从第T项到第1项的梯度,所以还需要额外求解一个第0项的更新,这样就得到了模型中的反向更新梯度。

代码编写

代码参考了这篇博客

ISTA

首先来编写一下ISTA算法,这里面的ISTA算法是用来作为LISTA算法的对比算法的,所以不需要编写反向传播的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np


def ista(X, W_d, a, L, max_iter, eps):
'''
X是输入数据
W_d是字典编码数据
a是正则化项惩罚系数
L是Lipschitz常数
'''

def shrinkage(x, theta):
'''
闭包函数,软阈值
'''
return np.multiply(np.sign(x), np.maximum(np.abs(x) - theta, 0))

# 首先计算W_d^T*W_d的特征值,并保证Lipschitz常数大于最大特征值
eig, eig_vector = np.linalg.eig(W_d.T * W_d)
assert L > np.max(eig)
del eig, eig_vector

# 按照公式计算W_e
W_e = W_d.T / L

recon_errors = []

# 开始迭代计算出编码Z,设定一个种子
Z_old = np.zeros((W_d.shape[1], 1))

for i in range(max_iter):
temp = W_d * Z_old - X
Z_new = shrinkage(Z_old - W_e * temp, a / L)

# 如果两次迭代的差值小于eps,则停止迭代
if np.sum(np.abs(Z_new - Z_old)) <= eps:
break
Z_old = Z_new

# 计算重构误差
# 需要加上正则化项
recon_error = np.linalg.norm(
X - W_d * Z_new, 2) ** 2 + a * np.linalg.norm(Z_new, 1)
recon_errors.append(recon_error)

return Z_new, recon_errors

简单验证一下,可以看到稀疏重建的效果,分别展示两个数据的重建效果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from scipy.linalg import orth
import matplotlib.pyplot as plt
import numpy as np


def ista(X, W_d, a, L, max_iter, eps):
'''
X是输入数据
W_d是字典编码数据
a是正则化项惩罚系数
L是Lipschitz常数
'''

def shrinkage(x, theta):
'''
闭包函数,软阈值
'''
return np.multiply(np.sign(x), np.maximum(np.abs(x) - theta, 0))

# 首先计算W_d^T*W_d的特征值,并保证Lipschitz常数大于最大特征值
eig, eig_vector = np.linalg.eig(W_d.T * W_d)
assert L > np.max(eig)
del eig, eig_vector

# 按照公式计算W_e
W_e = W_d.T / L

recon_errors = []

# 开始迭代计算出编码Z,设定一个种子
Z_old = np.zeros((W_d.shape[1], 1))

for i in range(max_iter):
temp = W_d * Z_old - X
Z_new = shrinkage(Z_old - W_e * temp, a / L)

# 如果两次迭代的差值小于eps,则停止迭代
if np.sum(np.abs(Z_new - Z_old)) <= eps:
break
Z_old = Z_new

# 计算重构误差
# 需要加上正则化项
recon_error = np.linalg.norm(
X - W_d * Z_new, 2) ** 2+a*np.linalg.norm(Z_new, 1)
recon_errors.append(recon_error)
print(f"index {i}: {recon_error}")

return Z_new, recon_errors


# 表示输入数据维度、输出数据维度、稀疏度
m, n, k = 1000, 256, 5
# 仅仅重建两个数据
N = 2

# 构建字典W_d,随机构建一个
Psi = np.eye(m)
Phi = np.random.randn(n, m)
Phi = np.transpose(orth(np.transpose(Phi)))
W_d = np.dot(Phi, Psi)
print(W_d.shape)

# 生成稀疏信号Z和测量X
Z = np.zeros((N, m))
X = np.zeros((N, n))
for i in range(N):
index_k = np.random.choice(a=m, size=k, replace=False, p=None)
Z[i, index_k] = 5 * np.random.randn(k, 1).reshape([-1, ])
X[i] = np.dot(W_d, Z[i, :])


# ISTA算法,展示两个稀疏重建效果
Z_recon, recon_errors = ista(np.mat(X).T, np.mat(W_d), 0.1, 2, 1000, 0.002)
plt.subplot(2, 1, 1)
plt.plot(Z_recon.T[0].T, '--', label='ISTA')
plt.title('1')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(Z_recon.T[1].T, '--', label='ISTA')
plt.title('2')
plt.legend()
plt.show()

LISTA

然后看一下LISTA的算法,利用pytorch来编写,直接使用自动求导机制,不需要再额外计算雅可比乱七八糟的东西了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from scipy.linalg import orth
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class LISTA(nn.Module):
def __init__(self, n, m, W_e, max_iter, L, theta):
"""
n: 输入维度
m: 稀疏表示的维度
W_e: 字典
max_iter: 最大迭代次数
L: Lipschitz常数
theta: 阈值
"""

super(LISTA, self).__init__()
# 构建两个全连接层,分别为W_e和S
self._W = nn.Linear(in_features=n, out_features=m, bias=False)
self._S = nn.Linear(in_features=m, out_features=m,
bias=False)

self._S2 = nn.Linear(in_features=m, out_features=m,
bias=False)
self._S3 = nn.Linear(in_features=m, out_features=m,
bias=False)
# 创建阈值函数,阈值为theta
self.shrinkage = nn.Softshrink(theta)
self.theta = theta

self.max_iter = max_iter
self.A = W_e
self.L = L

def weights_init(self):
"""
按照伪代码来初始化S和W_e
"""
A = self.A.cpu().numpy()
L = self.L
S = torch.from_numpy(np.eye(A.shape[1]) - (1/L)*np.matmul(A.T, A))
S = S.float().to(device)
W = torch.from_numpy((1/L)*A.T)
W = W.float().to(device)

self._S.weight = nn.Parameter(S)
self._S2.weight = nn.Parameter(S)
self._S3.weight = nn.Parameter(S)
self._W.weight = nn.Parameter(W)

def forward(self, y):
"""
前向推断步,利用自动求导机制不需要再求解导数
"""
x = self.shrinkage(self._W(y))

if self.max_iter == 1:
return x

for iter in range(self.max_iter):
# 这是一个定长操作,且pytorch自动解决导数链,所以不需要额外保存信息
x = self.shrinkage(self._W(y) + self._S(x))
x = self.shrinkage(self._W(y) + self._S2(x))
x = self.shrinkage(self._W(y) + self._S3(x))

return x


def train_lista(Y, dictionary, a, L, max_iter=30):
"""
由于需要训练权重,所以还需要使用一个包装函数来训练网络
"""

n, m = dictionary.shape
n_samples = Y.shape[0]
batch_size = 32
steps_per_epoch = n_samples // batch_size

Y = torch.from_numpy(Y)
Y = Y.float().to(device)

W_d = torch.from_numpy(dictionary)
W_d = W_d.float().to(device)

# 构建网络
net = LISTA(n, m, W_d, max_iter=30, L=L, theta=a/L)
net = net.float().to(device)
net.weights_init()

# 一些训练超参数
learning_rate = 1e-2
criterion1 = nn.MSELoss()
criterion2 = nn.L1Loss()
all_zeros = torch.zeros(batch_size, m).to(device)
optimizer = torch.optim.SGD(
net.parameters(), lr=learning_rate, momentum=0.9)

loss_list = []
for epoch in range(100):
index_samples = np.random.choice(
a=n_samples, size=n_samples, replace=False, p=None)
Y_shuffle = Y[index_samples]
for step in range(steps_per_epoch):
Y_batch = Y_shuffle[step*batch_size:(step+1)*batch_size]
optimizer.zero_grad()

# 计算输出
X_h = net(Y_batch)
Y_h = torch.mm(X_h, W_d.T)

# 计算正则化loss
loss1 = criterion1(Y_batch.float(), Y_h.float())
loss2 = a * criterion2(X_h.float(), all_zeros.float())
loss = loss1 + loss2

# 自动求导更新模型
loss.backward()
optimizer.step()

with torch.no_grad():
loss_list.append(loss.detach().data)
print("epoch: {}, loss: {}".format(epoch, loss.detach().data))

return net, loss_list


# 表示输入数据维度、输出数据维度、稀疏度
m, n, k = 1000, 256, 5
# 表示训练样本数,因为要训练,所以多一点
N = 128

# 初始化字典
Psi = np.eye(m)
Phi = np.random.randn(n, m)
Phi = np.transpose(orth(np.transpose(Phi)))
W_d = np.dot(Phi, Psi)
print(W_d.shape)

# 生成数据
Z = np.zeros((N, m))
X = np.zeros((N, n))
for i in range(N):
index_k = np.random.choice(a=m, size=k, replace=False, p=None)
Z[i, index_k] = 5 * np.random.randn(k, 1).reshape([-1, ])
X[i] = np.dot(W_d, Z[i, :])

# 计算网络和loss
net, err_list = train_lista(X, W_d, 0.1, 2)

# 推断并画出第一个样本和对应的稀疏表示

X_h = net(torch.from_numpy(X[0:2]).float().to(device))
plt.subplot(2, 1, 1)
plt.plot(X[0])
plt.title("Original")
plt.subplot(2, 1, 2)
plt.plot(X_h.cpu().detach().numpy()[0], label="reconstruction")
plt.title("Reconstruction")
plt.legend()
plt.show()

这里搭建了具有3个block的网络来进行训练(也就是和之前出现的那个LISTA算法一样)

这里再详细说明一下生成数据的代码

1
2
3
4
5
6
7
8
9
# 生成数据
m, n, k = 1000, 256, 5

Z = np.zeros((N, m))
X = np.zeros((N, n))
for i in range(N):
index_k = np.random.choice(a=m, size=k, replace=False, p=None)
Z[i, index_k] = 5 * np.random.randn(k, 1).reshape([-1, ])
X[i] = np.dot(W_d, Z[i, :])

由于LISTA算法的目的是训练出参数,所以自然需要准备好训练数据,但是由于我们没有很好的数据集,所以这里就直接生成了一些数据。在生成数据时,是先创建好一个稀疏的Z(有k个元素为非0的稀疏向量),然后再通过一个随机生成字典W_d来计算出X,这样就能够保证X和Z是具有对应关系的数据集。

可以看一下对应的重建效果,通过学习出,能够在一定程度上更好的完成重建的工作

总结

使用LISTA算法可以从一个预定义的集合中学习到特征的分布,而使用pytorch,并不需要我们去负责繁琐的求导步骤,只需要按照前向传播和反向传播的步骤来写就可以了,pytorch的自动求导可以很好的帮助我们计算出对应的导数并完成更新操作。