当前位置: 首页 > news >正文

基于“动手学强化学习”的知识点(一):第 14 章 SAC 算法(gym版本 >= 0.26)

第 14 章 SAC 算法(gym版本 >= 0.26)

  • 摘要
  • SAC 算法(连续)
  • SAC 算法(离散)

摘要

本系列知识点讲解基于动手学强化学习中的内容进行详细的疑难点分析!具体内容请阅读动手学强化学习!


对应动手学强化学习——SAC 算法


SAC 算法(连续)

# -*- coding: utf-8 -*-import random
import gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
import rl_utilsclass PolicyNetContinuous(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNetContinuous, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)self.fc_std = torch.nn.Linear(hidden_dim, action_dim)'''作用:保存动作幅度的界限,便于后续对动作做缩放。数值例子:若 action_bound=2,最终动作将会在 [-2, 2] 范围内。'''self.action_bound = action_bounddef forward(self, x):x = F.relu(self.fc1(x))mu = self.fc_mu(x)std = F.softplus(self.fc_std(x))'''作用:使用上面计算得到的 mu 和 std 构造正态分布对象 dist。数值例子:- 这时构造的分布为 𝑁(0.8,0.474^2)。'''dist = Normal(mu, std)'''作用:从正态分布中采样,但采用“重参数化采样”(rsample),以便后续能对采样过程进行梯度反传。数值例子:- 例如,若采样时随机变量 ε 从标准正态分布中取到 0.3,则采样值为 0.8 + 0.474 * 0.3 ≈ 0.8 + 0.1422 = 0.9422。'''normal_sample = dist.rsample()  # rsample()是重参数化采样'''作用:计算刚采样值在原始正态分布下的对数概率密度。'''log_prob = dist.log_prob(normal_sample)'''作用:对采样的原始动作进行 tanh 激活,将其映射到 (-1, 1) 范围内,保证动作平滑且有界。数值例子:- 对于采样值 0.9422,torch.tanh(0.9422) ≈ 0.737。'''action = torch.tanh(normal_sample)# 计算tanh_normal分布的对数概率密度'''作用:由于经过了 tanh 非线性变换,原来的对数概率密度需要进行修正(Jacobian 修正项),这里用公式logp_action=logp_normal−log(1−tanh(action)^2+ϵ)注意:实际应用中,通常是对 normal_sample 进行修正,写法可能略有不同,但这里的目标是一致的——补偿 tanh 变换带来的概率密度变换。'''log_prob = log_prob - torch.log(1 - torch.tanh(action).pow(2) + 1e-7)action = action * self.action_boundreturn action, log_probclass QValueNetContinuous(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNetContinuous, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc_out = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)x = F.relu(self.fc1(cat))x = F.relu(self.fc2(x))return self.fc_out(x)class SACContinuous:''' 处理连续动作的SAC算法 '''"""解释:- 定义一个名为 SACContinuous 的类,用来实现针对连续动作的 Soft Actor-Critic 算法。"""def __init__(self, state_dim, hidden_dim, action_dim, action_bound,actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,device):"""定义构造函数,接收一系列超参数,分别代表状态维度、隐藏层神经元个数、动作维度、动作界限、各网络的学习率、目标熵、软更新参数、折扣因子和设备。"""'''# 策略网络使用前面定义的 PolicyNetContinuous 构造函数生成策略网络(actor),并将该网络放到指定设备上(例如 CPU 或 GPU)。'''self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim, action_bound).to(device)  '''# 第一个Q网络创建第一个 Q 网络,用于评估(状态,动作)对的价值,同样放到指定设备。'''self.critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device) '''# 第二个Q网络创建第二个 Q 网络,与第一个结构相同,用于双重估计,帮助缓解过估计问题。'''self.critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)  '''# 第一个目标Q网络构造第一个目标 Q 网络,其结构与 critic_1 相同,用于计算目标值(TD目标),以便实现平滑更新。'''self.target_critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device) '''# 第二个目标Q网络构造第二个目标 Q 网络,其结构与 critic_2 相同,用于目标值计算。'''self.target_critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)  '''# 令目标Q网络的初始参数和Q网络一样将 critic_1 网络的所有参数复制到 target_critic_1 中,使二者初始时完全一致。将 critic_2 网络的所有参数复制到 target_critic_2 中,使二者初始时完全一致。'''self.target_critic_1.load_state_dict(self.critic_1.state_dict())self.target_critic_2.load_state_dict(self.critic_2.state_dict())'''使用 Adam 优化器为策略网络分配优化器,学习率为 actor_lr。'''self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)'''为 critic_1 分配 Adam 优化器,学习率为 critic_lr。为 critic_2 分配 Adam 优化器,学习率为 critic_lr。'''self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=critic_lr)self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=critic_lr)# 使用alpha的log值,可以使训练结果比较稳定'''创建一个标量张量,用于存储温度参数 alpha 的对数值。初始值设为 log(0.01) ≈ -4.6052。这样做有助于稳定训练,因为直接优化正数会带来数值不稳定问题。'''self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)'''设置该张量的 requires_grad 属性为 True,表示在反向传播时会计算关于 log_alpha 的梯度,从而能更新温度参数。'''self.log_alpha.requires_grad = True  # 可以对alpha求梯度'''为 log_alpha 创建一个 Adam 优化器,学习率为 alpha_lr。注意优化器接收的是一个包含 log_alpha 的列表。'''self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)'''保存目标熵参数,这个值用于指导策略更新时保持足够的探索性。'''self.target_entropy = target_entropy  # 目标熵的大小'''保存折扣因子,用于计算未来奖励的折现值。'''self.gamma = gamma'''保存软更新系数 tau,用于更新目标网络的参数。'''self.tau = tau'''保存设备信息,便于后续将数据和模型都放到同一设备上。'''self.device = devicedef take_action(self, state):"""定义一个方法,根据当前状态输出一个动作(供环境交互时调用)。"""'''将传入的状态(例如一个列表或数组)转换为 PyTorch 张量,并在外面加一层列表以增加 batch 维度,然后将其放到指定设备上。state = [1,2,3,4]state = torch.tensor([state], dtype=torch.float).to("cuda")print(state) # tensor([[1., 2., 3., 4.]], device='cuda:0')state1 = [1,2,3,4]state1 = torch.tensor(state1, dtype=torch.float).to("cuda")print(state1) # tensor([1., 2., 3., 4.], device='cuda:0')state2 = [1,2,3,4]state2 = torch.tensor(state2, dtype=torch.float).unsqueeze(0).to("cuda")print(state2) # tensor([[1., 2., 3., 4.]], device='cuda:0')'''if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(self.device)'''解释:- 将状态输入 actor 网络,得到输出。由于 actor 的 forward 返回的是一个元组(动作、对数概率),这里取第一个元素(动作部分)。数值例子:- 假设 actor 返回 (tensor([[0.737]]), tensor([[-0.45]])),则 action = tensor([[0.737]]);再取 [0] 后得到单个样本的动作张量 tensor([0.737])。'''action = self.actor(state)[0]'''解释:- 将动作张量转换为 Python 标量,并放入列表后返回。数值例子:- action.item() 会返回 0.737,最终返回 [0.737]。这样可以适应环境要求动作为列表格式的情况。'''return [action.item()]def calc_target(self, rewards, next_states, dones):  # 计算目标Q值"""定义一个方法,利用下一时刻状态、奖励和 done 标志计算 TD 目标(目标 Q 值),用于 critic 网络的回归训练。"""'''对所有下一状态(通常是一个 batch),利用 actor 网络计算下一时刻动作和其对应的对数概率。数值例子:假设 next_states 有 2 个样本,每个样本状态为 3 维;actor 返回- next_actions = tensor([[1.2], [0.8]])- log_prob = tensor([[-0.5], [-0.6]])'''next_actions, log_prob = self.actor(next_states)'''计算熵项,实际上熵等于负的对数概率。数值例子:如果 log_prob = tensor([[-0.5], [-0.6]]),则 entropy = tensor([[0.5], [0.6]])。'''entropy = -log_prob'''使用目标网络1计算给定下一状态和对应动作的 Q 值。'''q1_value = self.target_critic_1(next_states, next_actions)'''使用目标网络2计算给定下一状态和对应动作的 Q 值。'''q2_value = self.target_critic_2(next_states, next_actions)'''解释:- 首先,取两个目标 Q 值的最小值(用来降低过估计风险);- 然后加上温度参数 alpha(由 self.log_alpha.exp() 得到)乘以熵项,这一项鼓励探索。数值例子:- 对第一样本:min(2.0, 2.5) = 2.0,且 self.log_alpha.exp() 计算为 exp(-4.6052) ≈ 0.01;熵为 0.5,则 next_value = 2.0 + 0.01 * 0.5 = 2.0 + 0.005 = 2.005。- 对第二样本:min(3.0, 3.5) = 3.0,熵为 0.6,则 next_value = 3.0 + 0.01 * 0.6 = 3.0 + 0.006 = 3.006。'''next_value = torch.min(q1_value, q2_value) + self.log_alpha.exp() * entropy'''计算 TD 目标:td_target = 𝑟 + 𝛾 × next_value × (1−done)当 done 为 1(表示回合结束)时,不再折扣未来奖励。'''td_target = rewards + self.gamma * next_value * (1 - dones)return td_targetdef soft_update(self, net, target_net):"""定义一个方法,用于对目标网络参数做软更新。传入当前网络和对应的目标网络。"""'''遍历目标网络和当前网络中对应的每一对参数(权重和偏置)。'''for param_target, param in zip(target_net.parameters(), net.parameters()):'''对每个参数做软更新:𝜃target←(1−𝜏)𝜃target+𝜏𝜃θtarget←(1−τ)θ target+τθ这可以平滑地将目标网络参数向当前网络参数靠拢。'''param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)def update(self, transition_dict):"""定义一个方法,根据从 replay buffer 中采样的转换数据(transition)更新 actor、critic 网络以及温度参数 alpha。"""'''将 transition_dict 中的状态数据转换为浮点型张量,并放到指定设备上。'''states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)'''同理,将动作数据转换为张量,并通过 view(-1, 1) 调整形状为 (batch_size, 1)(即每个动作为一个标量)。数值例子:若 transition_dict['actions'] = [1.0, 0.5],转换后形状为 (2, 1)。'''actions = torch.tensor(transition_dict['actions'], dtype=torch.float).view(-1, 1).to(self.device)'''将奖励数据转换为形状为 (batch_size, 1) 的张量。数值例子:若 rewards = [1.0, -0.5],则转换后为形状 (2, 1)。'''rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)'''将下一时刻的状态数据转换为张量。数值例子:例如 next_states = [[1.1, 0.4, -0.1], [0.2, 0.0, 0.9]],形状 (2, 3)。'''next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)'''将 done 标志(0 或 1)转换为形状为 (batch_size, 1) 的张量,用于指示回合是否结束。数值例子:若 dones = [0, 1],则转换后为 tensor([[0.0], [1.0]])。'''dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)# 和之前章节一样,对倒立摆环境的奖励进行重塑以便训练'''对奖励进行归一化或重塑,使其数值范围更适合训练。对于倒立摆(或类似)环境,原始奖励可能范围较大,这里将所有奖励平移 8.0 后除以 8.0。数值例子:- 如果原始 reward = -8.0,则 ( -8.0 + 8.0) / 8.0 = 0;- 如果 reward = 0,则变为 1;- 如果 reward = 8,则变为 2。'''rewards = (rewards + 8.0) / 8.0# 更新两个Q网络'''调用前面定义的 calc_target 方法,根据重塑后的奖励、下一状态和 done 标志计算 TD 目标。'''td_target = self.calc_target(rewards, next_states, dones)'''计算 critic_1 的均方误差(MSE)损失。- 调用 self.critic_1(states, actions) 得到当前 Q 值估计;- 使用 td_target.detach() 表示目标值不参与梯度计算;- 用 MSE 损失函数计算误差,再取平均。数值例子:- 假设 critic_1 输出 Q 值为 2.5,td_target 为 2.98,则误差为 (2.5−2.98)^2≈0.2304;对 batch 求均值。'''critic_1_loss = torch.mean(F.mse_loss(self.critic_1(states, actions), td_target.detach()))critic_2_loss = torch.mean(F.mse_loss(self.critic_2(states, actions), td_target.detach()))'''清空 critic_1 优化器中所有累积的梯度,防止梯度累加。'''self.critic_1_optimizer.zero_grad()'''对 critic_1 损失进行反向传播,计算每个参数的梯度。'''critic_1_loss.backward()'''更新 critic_1 网络的参数,根据之前计算的梯度和设定的学习率进行一步更新。'''self.critic_1_optimizer.step()'''清空 critic_2 优化器中所有累积的梯度,防止梯度累加。'''self.critic_2_optimizer.zero_grad()'''对 critic_2 损失进行反向传播,计算每个参数的梯度。'''critic_2_loss.backward()'''更新 critic_2 网络的参数,根据之前计算的梯度和设定的学习率进行一步更新。'''self.critic_2_optimizer.step()# 更新策略网络'''使用当前 actor 网络,根据当前状态生成一组新的动作及其对数概率,用于策略更新。'''new_actions, log_prob = self.actor(states)'''计算熵项,即负的对数概率。'''entropy = -log_prob'''用当前 critic_1 网络评估新生成动作的 Q 值。'''q1_value = self.critic_1(states, new_actions)'''用当前 critic_2 网络评估新生成动作的 Q 值。'''q2_value = self.critic_2(states, new_actions)'''计算策略网络(actor)的损失。- 第一项:−𝛼 × entropy 用于鼓励策略探索;- 第二项:−min(𝑞1, 𝑞2) 表示希望选择高价值动作;- 取均值作为整个 batch 的损失。'''actor_loss = torch.mean(-self.log_alpha.exp() * entropy - torch.min(q1_value, q2_value))'''对 actor 网络进行梯度清零、反向传播和参数更新。'''self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 更新alpha值'''计算温度参数 alpha 的损失。- (entropy - self.target_entropy) 表示当前熵与目标熵之间的偏差;- 用 detach() 阻断梯度传递给 entropy(即仅更新 alpha);- 乘以当前的 𝛼 = exp{log(𝛼);- 取均值作为总体损失。'''alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())'''清空 log_alpha 的梯度、反向传播损失并更新 log_alpha 参数。'''self.log_alpha_optimizer.zero_grad()alpha_loss.backward()self.log_alpha_optimizer.step()'''调用之前定义的 soft_update 方法,对两个目标 Q 网络分别做软更新,使得目标网络参数慢慢跟随当前 Q 网络的更新。'''self.soft_update(self.critic_1, self.target_critic_1)self.soft_update(self.critic_2, self.target_critic_2)env_name = 'Pendulum-v1'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high[0]  # 动作最大值
random.seed(0)
np.random.seed(0)
if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)actor_lr = 3e-4
critic_lr = 3e-3
alpha_lr = 3e-4
num_episodes = 100
hidden_dim = 128
gamma = 0.99
tau = 0.005  # 软更新参数
buffer_size = 100000
minimal_size = 1000
batch_size = 64
target_entropy = -env.action_space.shape[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound,actor_lr, critic_lr, alpha_lr, target_entropy, tau,gamma, device)return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,replay_buffer, minimal_size,batch_size)    episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

SAC 算法(离散)

# -*- coding: utf-8 -*-import random
import gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
import rl_utilsclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)class QValueNet(torch.nn.Module):''' 只有一层隐藏层的Q网络 '''def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class SAC:''' 处理离散动作的SAC算法 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,alpha_lr, target_entropy, tau, gamma, device):# 策略网络self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)# 第一个Q网络self.critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)# 第二个Q网络self.critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)self.target_critic_1 = QValueNet(state_dim, hidden_dim,action_dim).to(device)  # 第一个目标Q网络self.target_critic_2 = QValueNet(state_dim, hidden_dim,action_dim).to(device)  # 第二个目标Q网络# 令目标Q网络的初始参数和Q网络一样self.target_critic_1.load_state_dict(self.critic_1.state_dict())self.target_critic_2.load_state_dict(self.critic_2.state_dict())self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=actor_lr)self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),lr=critic_lr)self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),lr=critic_lr)# 使用alpha的log值,可以使训练结果比较稳定self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)self.log_alpha.requires_grad = True  # 可以对alpha求梯度self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],lr=alpha_lr)self.target_entropy = target_entropy  # 目标熵的大小self.gamma = gammaself.tau = tauself.device = devicedef take_action(self, state):if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(self.device)probs = self.actor(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()# 计算目标Q值,直接用策略网络的输出概率进行期望计算def calc_target(self, rewards, next_states, dones):next_probs = self.actor(next_states)next_log_probs = torch.log(next_probs + 1e-8)entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)q1_value = self.target_critic_1(next_states)q2_value = self.target_critic_2(next_states)min_qvalue = torch.sum(next_probs * torch.min(q1_value, q2_value),dim=1,keepdim=True)next_value = min_qvalue + self.log_alpha.exp() * entropytd_target = rewards + self.gamma * next_value * (1 - dones)return td_targetdef soft_update(self, net, target_net):for param_target, param in zip(target_net.parameters(),net.parameters()):param_target.data.copy_(param_target.data * (1.0 - self.tau) +param.data * self.tau)def update(self, transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)  # 动作不再是float类型rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)# 更新两个Q网络td_target = self.calc_target(rewards, next_states, dones)critic_1_q_values = self.critic_1(states).gather(1, actions)critic_1_loss = torch.mean(F.mse_loss(critic_1_q_values, td_target.detach()))critic_2_q_values = self.critic_2(states).gather(1, actions)critic_2_loss = torch.mean(F.mse_loss(critic_2_q_values, td_target.detach()))self.critic_1_optimizer.zero_grad()critic_1_loss.backward()self.critic_1_optimizer.step()self.critic_2_optimizer.zero_grad()critic_2_loss.backward()self.critic_2_optimizer.step()# 更新策略网络probs = self.actor(states)log_probs = torch.log(probs + 1e-8)# 直接根据概率计算熵entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)  #q1_value = self.critic_1(states)q2_value = self.critic_2(states)min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value),dim=1,keepdim=True)  # 直接根据概率计算期望actor_loss = torch.mean(-self.log_alpha.exp() * entropy - min_qvalue)self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 更新alpha值alpha_loss = torch.mean((entropy - target_entropy).detach() * self.log_alpha.exp())self.log_alpha_optimizer.zero_grad()alpha_loss.backward()self.log_alpha_optimizer.step()self.soft_update(self.critic_1, self.target_critic_1)self.soft_update(self.critic_2, self.target_critic_2)    actor_lr = 1e-3
critic_lr = 1e-2
alpha_lr = 1e-2
num_episodes = 200
hidden_dim = 128
gamma = 0.98
tau = 0.005  # 软更新参数
buffer_size = 10000
minimal_size = 500
batch_size = 64
target_entropy = -1
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v0'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = SAC(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, alpha_lr,target_entropy, tau, gamma, device)return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,replay_buffer, minimal_size,batch_size)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

相关文章:

基于“动手学强化学习”的知识点(一):第 14 章 SAC 算法(gym版本 >= 0.26)

第 14 章 SAC 算法(gym版本 > 0.26) 摘要SAC 算法(连续)SAC 算法(离散) 摘要 本系列知识点讲解基于动手学强化学习中的内容进行详细的疑难点分析!具体内容请阅读动手学强化学习&…...

【QT:信号和槽】

QT信号涉及的三要素:信号源、信号类型、信号的处理方式。 QT的信号槽机制: 给按钮的点击操作关联一个处理函数,用户点击按钮时触发,对应的处理函数就会执行 QT中使用connect函数将信号和槽关联起来,信号触发&#xf…...

Oracle中的INHERIT PRIVILEGES权限

Oracle中的INHERIT PRIVILEGES权限 存储过程和用户函数的AUTHID属性调用者权限vs定义者权限一个简单的示例INHERIT PRIVILEGES权限的含义INHERIT PRIVILEGES权限的安全隐患注意到Oracle 19c数据库中有如下权限信息: SQL> select grantor,grantee,table_name,privilege fro…...

Compose笔记(九)--Checkbox

这一节主要了解一下Compose中的Checkbox,它是Jetpack Compose UI框架中的一个组件,用于创建复选框功能。它允许用户从一个集合中选择一个或多个项目,可以将一个选项打开或关闭。与传统的Android View系统中的Checkbox相比,Compose…...

CSS中粘性定位

1.如何设置为粘性定位? 给元素设置posttion:sticky 即可实现粘性定位. 可以使用left, right ,top, bottom 四个属性调整位置,不过最常用的是top 值. 2.粘性定位的参考点在哪里? 离他最近的一个拥有"滚动机制"的祖先元素,即便这个祖先不是最近的真实可滚动祖先. 3.粘…...

日本IT|AWS工作内容及未来性、以及转职的所需资质和技能

AWSとは AWSはAmazon Web Services(アマゾンウェブサービス)の略称です。 名称から分かるとおり、ネットを通じた通販などを事業として行っているAmazon.com社がクラウドサービスとして運営しています。 本来であれば自分たちでインフラ環境を構築する…...

《Spring日志整合与注入技术:从入门到精通》

1.Spring与日志框架的整合 1.Spring与日志框架进行整合,日志框架就可以在控制台中,输出Spring框架运行过程中的一些重要的信息。 好处:方便了解Spring框架的运行过程,利于程序的调试。 Spring如何整合日志框架 Spring5.x整合log4j…...

如何判断一个项目用的是哪个管理器

如何判断一个项目用的是哪个管理器 npm: 如果项目中存在 package-lock.json 文件,这通常意味着项目使用 npm 作为包管理器。package-lock.json 文件会锁定项目的依赖版本,确保在不同环境中安装相同的依赖。 pnpm: 如果项目中存在 pnpm-lock.yaml 文件&a…...

软件工程概述

软件开发生命周期 软件定义时期:包括可行性研究和详细需求分析,任务是确定软件开发的总目标。 问题定义可行性研究(经济、技术、操作、社会可行性,确定问题和解决办法)需求分析(确定功能需求,性…...

文件系统 linux ─── 第19课

前面博客讲解的是内存级文件管理,接下来介绍磁盘级文件管理 文件系统分为两部分 内存级文件系统 : OS加载进程 ,进程打开文件, OS为文件创建struct file 和文件描述符表 ,将进程与打开的文件相连, struct file 内还函数有指针表, 屏蔽了底层操作的差异,struct file中还有内核级…...

一篇博客搞定时间复杂度

时间复杂度 1、什么是时间复杂度?2、推导大O的规则3、时间复杂度的计算3.1 基础题 13.2 基础题 23.3基础题 33.4进阶题 13.5进阶题 23.6 偏难题 13.7偏难题 2(递归) 前言: 算法在编写成可执行程序后,运行时要耗费时间和…...

微信小程序实现根据不同的用户角色显示不同的tabbar并且可以完整的切换tabbar

直接上图上代码吧 // login/login.js const app getApp() Page({/*** 页面的初始数据*/data: {},/*** 生命周期函数--监听页面加载*/onLoad(options) {},/*** 生命周期函数--监听页面初次渲染完成*/onReady() {},/*** 生命周期函数--监听页面显示*/onShow() {},/*** 生命周期函…...

S_on@atwk的意思

S_onatwk 可能是某种自动化或控制系统中的符号或标记,尤其在PLC(可编程逻辑控制器)编程中,类似的表达方式通常用于表示特定的信号、状态或操作。 我们可以分析这个表达式的各个部分: S_on:通常&#xff0…...

Liunx启动kafka并解决kafka时不时挂掉的问题

kafka启动步骤 先启动zookeeper,启动命令如下 nohup ./zookeeper-server-start.sh /home/kafka/kafka/config/zookeeper.properties > /home/kafka/kafka/zookeeper.log 2>&1 &再启动kafka,启动命令如下 nohup ./kafka-server-start.sh…...

16 | 实现简洁架构的 Store 层

提示: 所有体系课见专栏:Go 项目开发极速入门实战课;欢迎加入 云原生 AI 实战 星球,12 高质量体系课、20 高质量实战项目助你在 AI 时代建立技术竞争力(聚焦于 Go、云原生、AI Infra);本节课最终…...

华为hcia——Datacom实验指南——以太网帧和IPV4数据包格式(一)

实验开始 第一步配置环境 第二步配置客户端 如图所示,我们把客户端的ip配置成192.168.1.10,网关设为192.168.1.1 第三步配置交换机1 system-view sysname LSW1 vlan batch 10 interface ethernet0/0/1 port link-type access port default vlan 10 qu…...

ubuntu软件——视频、截图、图片、菜单自定义等

视频软件,大部分的编码都能适应 sudo apt install vlc图片软件 sudo apt install gwenview截图软件 sudo apt install flameshot设置快捷键 flameshot flameshot gui -p /home/cyun/Pictures/flameshot也就是把它保存到一个自定义的路径 菜单更换 sudo apt r…...

CSS中z-index使用详情

定位层级 1.定位元素的显示层级比普通元素高,无论什么定位,显示层级都是一样的; 2.如果位置发生重叠,默认情况是:后面的元素,会显示在前面元素之上; 3.可以通过CSS属性z-index调整元素的显示层级; 4.z-index的属性值是数字,没有单位,值越大显示层级越高; 5.只有定位的元素…...

qt 自带虚拟键盘的编译使用记录

一、windows 下编译 使用vs 命令窗口,分别执行: qmake CONFIG"lang-en_GB lang-zh_CN" nmake nmake install 如果事先没有 指定需要使用的输入法语言就进行过编译,则需要先 执行 nmake distclean 清理后执行 qmake 才能生效。 …...

杨辉三角形(信息学奥赛一本通-2043)

【题目描述】 例5.11 打印杨辉三角形的前n(2≤n≤20)行。杨辉三角形如下图: 当n5时 1 1 1 1 2 1 1 3 3 1 1 4 6 4 1 输出: 1 1 1 1 2 1 1 3 3 1 1 4 6 4 1 【输入】 输入行数n。 【输出】 输出如题述三角形。n行&#…...

CentOS 7 系统上安装 SQLite

1. 检查系统更新 在安装新软件之前,建议先更新系统的软件包列表,以确保使用的是最新的软件源和补丁。打开终端,执行以下命令: sudo yum update -y -y 选项表示在更新过程中自动回答 “yes”,避免手动确认。 2. 安装 …...

程序化广告行业(13/89):DSP的深入解析与运营要点

程序化广告行业(13/89):DSP的深入解析与运营要点 大家好!一直以来,我都对程序化广告行业保持着浓厚的学习兴趣,在探索的过程中积累了不少心得。今天就想把这些知识分享出来,和大家一起学习进步…...

使用 Doris 和 LakeSoul

作为一种全新的开放式的数据管理架构,湖仓一体(Data Lakehouse)融合了数据仓库的高性能、实时性以及数据湖的低成本、灵活性等优势,帮助用户更加便捷地满足各种数据处理分析的需求,在企业的大数据体系中已经得到越来越…...

datax源码分析

文章目录 前言一、加载配置文件二、根据加载的配置文件进行调度三、根据配置文件执行读取写入任务总结 前言 在上一篇文章当中我们已经了解了datax的启动原理,以及datax的最基础的配置,datax底层java启动类的入口及关键参数。 接下来我将进行启动类执行…...

【HDLbits--分支预测器简单实现】

HDLbits--分支预测器简单实现 1 timer2.branche predicitors3.Branch history shift4.Branch direction predictor 以下是分支预测器的简单其实现; 1 timer 实现一个计时器,当load1’b1时,加载data进去,当load1’b0时进行倒计时&…...

优化Go错误码管理:构建清晰、优雅的HTTP和gRPC错误码规范

在系统开发过程中,如何优雅地管理错误信息一直是个棘手问题。传统的错误处理方式往往存在不统一、难以维护等缺点。而 errcode 模块通过对错误码进行规范化管理,为系统级和业务级错误提供了统一的编码标准。本文将带您深入了解 errcode 的设计原理、错误…...

批量压缩与优化 Excel 文档,减少 Excel 文档大小

当我们在 Excel 文档中插入图片资源的时候,如果我们插入的是原图,可能会导致 Excel 变得非常的大。这非常不利于我们传输或者共享。那么当我们的 Excel 文件非常大的时候,我们就需要对文档做一些压缩或者优化的处理。那有没有什么方法可以实现…...

MongoDB分页实现方式对比:PageRequest vs Skip/Limit

MongoDB分页实现方式对比:PageRequest vs Skip/Limit 一、基本概念1.1 PageRequest分页1.2 Skip/Limit分页 二、主要区别2.1 使用方式2.2 参数计算2.3 适用场景PageRequest适用场景:Skip/Limit适用场景: 三、性能考虑3.1 PageRequest的性能特…...

SAP Commerce(Hybris)营销模块(一):商城产品折扣配置

基于Hybris的Backoffice后台管理系统,创建一个基于模板的营销规则,并配置上对应的优惠活动。 架构设计 先从一张架构图说起 Hybris的促销模块,是基于Promotion引擎来实现的,可以通过Backoffice来进行配置。 通过上面的架构图又可…...

如何在 React 中实现错误边界?

在 React 中实现错误边界 错误边界是 React 提供的一种机制,用于捕获子组件树中的 JavaScript 错误,并展示回退 UI。它可以帮助开发者更好地处理错误,提升用户体验。本文将详细介绍如何在 React 中实现错误边界,包括其工作原理、…...

从头开始开发基于虹软SDK的人脸识别考勤系统(python+RTSP开源)(五)完整源码已上传!

本篇是对照之前代码剩余的部分代码做补充,分享给大家,便于对照运行测试。 完整版的全功能单文件版本已上传!https://download.csdn.net/download/xiaomage_cn/90484179 人脸识别抽象层,这个大家应该都知道,就是为了方…...

PySide(PyQT)的mouseMoveEvent()和hoverMoveEvent()的区别

在 PySide中,mouseMoveEvent 和 hoverMoveEvent 都是用于处理鼠标移动相关操作的事件,但它们之间存在明显的区别: 事件触发条件 • mouseMoveEvent: 当鼠标在对应的图形项(如 QGraphicsPixmapItem&#xff09…...

【通缩螺旋的深度解析与科技破局路径】

通缩螺旋的深度解析与科技破局路径 一、通缩螺旋的形成机制与恶性循环 通缩螺旋(Deflationary Spiral)是经济学中描述价格持续下跌与经济衰退相互强化的动态过程,其核心逻辑可拆解为以下链条: 需求端萎缩:居民消费信…...

【如何使用云服务器与API搭建专属聊天系统:宝塔面板 + Openwebui 完整教程】

文章目录 不挑电脑、不用技术,云服务器 API 轻松搭建专属聊天系统,对接 200 模型,数据全在自己服务器,安全超高一、前置准备:3 分钟快速上手指南云服务器准备相关账号注册 二、手把手部署教程(含代码块&a…...

Oracle数据库存储结构--逻辑存储结构

数据库存储结构:分为物理存储结构和逻辑存储结构。 物理存储结构:操作系统层面如何组织和管理数据 逻辑存储结构:Oracle数据库内部数据组织和管理数据,数据库管理系统层面如何组织和管理数据 Oracle逻辑存储结构 数据库的逻…...

C++ 左值(lvalue)和右值(rvalue)

在 C 中,左值(lvalue)和右值(rvalue)是指对象的不同类别,区分它们对于理解 C 中的表达式求值和资源管理非常重要,尤其在现代 C 中涉及到移动语义(Move Semantics)和完美转…...

《实战AI智能体》DeepSearcher 的架构设计

DeepSearcher 的架构设计 一个通往搜索AGI的Agentic RAG应该如何设计? 从架构上看,DeepSearcher 主要分为两大模块。 一个是数据接入模块,通过Milvus向量数据库来接入各种第三方的私有知识。这也是DeepSearcher相比OpenAI的原本DeepResearc…...

Kotlin 继承

Kotlin 继承 概述 Kotlin 是一种现代的编程语言,它具有简洁、安全、互操作性等特点。在面向对象编程中,继承是一种非常重要的特性,它允许我们创建具有共同属性和方法的类。本文将详细介绍 Kotlin 中的继承机制,包括继承的基本概…...

【6】树状数组学习笔记

前言 树状数组是我学的第一个高级数据结构,属于 log ⁡ \log log 级数据结构。 其实现在一般不会单独考察数据结构,主要是其在其他算法(如贪心,DP)中起到优化作用。 长文警告:本文一共 995 995 995 行…...

【RISCV LAB】0x01-安装实验仿真辅助工具

安装实验辅助工具 实验环境搭建安装 Verilator编译依赖下载源码编译安装测试安装 安装 RISC-V 交叉编译工具链编译依赖下载源码编译安装编译并安装添加环境变量并测试 安装 GTKWave其他模拟器推荐RARSemulsiV FAQ 实验环境搭建 Verilator 是一款开源的支持 Verilog 和 SystemV…...

OSPF-2 邻接建立关系

上一期我们说了OSPF的邻居建立关系以及OSPF邻居关系建立中建立失败的因素以及相关实验案例 这一期我们来说说OSPF的邻接关系建立时需要交互哪些报文以及失败因素及原因和相关实验案例 一、概述 在运行了OSPF的网络当中为了交互链路状态信息和路由信息,互相之间需要建立邻接关…...

操作系统知识点29

1.当用户使用外部设备时,其控制设备的命令传递途径依次为用户应用层->设备独立层->设备驱动层->设备硬件 2.通常用于管理空闲物理内存的方法:空闲快链表法;位示图法;空闲页面表 3. 可用于文件的存取控制和保护的方法&a…...

【Java篇】行云流水,似风分岔:编程结构中的自然法则

文章目录 Java 程序逻辑控制:顺序、分支与循环结构全面解析一、顺序结构二、分支结构2.1 if 语句2.1.1 基本语法2.1.2 if-else 语句2.1.3 if-else if-else 语句 2.2 switch 语句 三、循环结构3.1 while 循环3.2 break 语句3.3 continue 语句3.4 for 循环 四、输入输…...

代码块与设计模式

文章目录 1.代码块1.1基本介绍基本语法 1.2代码块的好处和案例演示1.3代码块使用注意事项和细节讨论!!! 2.单例设计模式2.1什么是设计模式2.2什么是单例模式2.2.1饿汉式2.2.2懒汉式2.2.3比较 1.代码块 1.1基本介绍 代码化块又称为初始化块,属于类中的成员[即是类的一部分]&am…...

要登录的设备ip未知时的处理方法

目录 1 应用场景... 1 2 解决方法:... 1 2.1 wireshark设置... 1 2.2 获取网口mac地址,wireshark抓包前预过滤掉自身mac地址的影响。... 2 2.3 pc网口和设备对接... 3 2.3.1 情况1:... 3 2.3.2 情…...

CentOS 系统安装 docker 以及常用插件

博主用的的是WindTerm软件链接的服务器,因为好用 1.链接上服务器登入后,在/root/目录下 2.执行以下命令安装docker sudo yum install -y yum-utilssudo yum-config-manager \--add-repo \https://download.docker.com/linux/centos/docker-ce.reposudo…...

统计字符(字符串)(gets与fgets的区别)

统计字符 #include<stdio.h> #include<string.h> int main(){char str1[5],str2[80];while(gets(str1)){if(strcmp(str1,"#")0)break;gets(str2);for(int i0;i<strlen(str1);i){int sum0;for(int j0;j<strlen(str2);j){if(str1[i]str2[j])sum;}p…...

Node.js REPL 深入解析

Node.js REPL 深入解析 引言 Node.js 作为一种流行的 JavaScript 运行环境,在服务器端开发中扮演着重要角色。REPL(Read-Eval-Print Loop,读取-求值-打印循环)是 Node.js 的一个核心特性,它允许开发者在一个交互式环境中执行 JavaScript 代码。本文将深入探讨 Node.js R…...

【测试语言基础篇】Python基础之List列表

一、Python 列表(List) 序列是Python中最基本的数据结构。序列中的每个元素都分配一个数字 - 它的位置&#xff0c;或索引&#xff0c;第一个索引是0&#xff0c;第二个索引是1&#xff0c;依此类推。 Python有6个序列的内置类型&#xff0c;但最常见的是列表和元组。序列都可…...

中山六院团队发表可解释多模态融合模型Brim,可以在缺少分子数据时借助病理图像模拟生成伪基因组特征|顶刊解读·25-02-14

小罗碎碎念 在癌症诊疗领域&#xff0c;精准预测患者预后对临床决策意义重大。传统的癌症分期系统&#xff0c;如TNM分期&#xff0c;因无法充分考量肿瘤异质性&#xff0c;难以准确预测患者的临床结局。而基于人工智能的多模态融合模型虽有潜力&#xff0c;但在实际临床应用中…...