diff --git a/scripts/gyms/dribble.py b/scripts/gyms/dribble.py index c78d9a2..0211373 100644 --- a/scripts/gyms/dribble.py +++ b/scripts/gyms/dribble.py @@ -259,6 +259,7 @@ class dribble(gym.Env): def close(self): Draw.clear_all() self.player.terminate() + def execute(self, action): # Actions: @@ -354,7 +355,7 @@ class dribble(gym.Env): loss = self.loss(obs, action_p, action_r) # 计算奖励 - reward = np.linalg.norm(w.ball_cheat_abs_vel) * cos_theta + reward = np.linalg.norm(w.ball_cheat_abs_vel) * cos_theta + loss if self.ball_dist_hip_center_2d < 0.115: reward = 0