![]()
1 import numpy as np
2 import matplotlib.pyplot as plt
3 import random
4
5 class SA(object):
6
7 def __init__(self, interval, tab='min', T_max=10000, T_min=1, iterMax=1000, rate=0.95):
8 self.interval = interval # 给定状态空间 - 即待求解空间
9 self.T_max = T_max # 初始退火温度 - 温度上限
10 self.T_min = T_min # 截止退火温度 - 温度下限
11 self.iterMax = iterMax # 定温内部迭代次数
12 self.rate = rate # 退火降温速度
13 #############################################################
14 self.x_seed = random.uniform(interval[0], interval[1]) # 解空间内的种子
15 self.tab = tab.strip() # 求解最大值还是最小值的标签: 'min' - 最小值;'max' - 最大值
16 #############################################################
17 self.solve() # 完成主体的求解过程
18 self.display() # 数据可视化展示
19
20 def solve(self):
21 temp = 'deal_' + self.tab # 采用反射方法提取对应的函数
22 if hasattr(self, temp):
23 deal = getattr(self, temp)
24 else:
25 exit('>>>tab标签传参有误:"min"|"max"<<<')
26 x1 = self.x_seed
27 T = self.T_max
28 while T >= self.T_min:
29 for i in range(self.iterMax):
30 f1 = self.func(x1)
31 delta_x = random.random() * 2 - 1
32 if x1 + delta_x >= self.interval[0] and x1 + delta_x <= self.interval[1]: # 将随机解束缚在给定状态空间内
33 x2 = x1 + delta_x
34 else:
35 x2 = x1 - delta_x
36 f2 = self.func(x2)
37 delta_f = f2 - f1
38 x1 = deal(x1, x2, delta_f, T)
39 T *= self.rate
40 self.x_solu = x1 # 提取最终退火解
41
42 def func(self, x): # 状态产生函数 - 即待求解函数
43 value = np.sin(x**2) * (x**2 - 5*x)
44 return value
45
46 def p_min(self, delta, T): # 计算最小值时,容忍解的状态迁移概率
47 probability = np.exp(-delta/T)
48 return probability
49
50 def p_max(self, delta, T):
51 probability = np.exp(delta/T) # 计算最大值时,容忍解的状态迁移概率
52 return probability
53
54 def deal_min(self, x1, x2, delta, T):
55 if delta < 0: # 更优解
56 return x2
57 else: # 容忍解
58 P = self.p_min(delta, T)
59 if P > random.random(): return x2
60 else: return x1
61
62 def deal_max(self, x1, x2, delta, T):
63 if delta > 0: # 更优解
64 return x2
65 else: # 容忍解
66 P = self.p_max(delta, T)
67 if P > random.random(): return x2
68 else: return x1
69
70 def display(self):
71 print('seed: {}\nsolution: {}'.format(self.x_seed, self.x_solu))
72 plt.figure(figsize=(6, 4))
73 x = np.linspace(self.interval[0], self.interval[1], 300)
74 y = self.func(x)
75 plt.plot(x, y, 'g-', label='function')
76 plt.plot(self.x_seed, self.func(self.x_seed), 'bo', label='seed')
77 plt.plot(self.x_solu, self.func(self.x_solu), 'r*', label='solution')
78 plt.title('solution = {}'.format(self.x_solu))
79 plt.xlabel('x')
80 plt.ylabel('y')
81 plt.legend()
82 plt.savefig('SA.png', dpi=500)
83 plt.show()
84 plt.close()
85
86
87 if __name__ == '__main__':
88 SA([-5, 5], 'max')