# Source code for lagom.metric.td

import numpy as np

from lagom.utils import numpify


def td0_target(gamma, rewards, Vs, last_V, reach_terminal):
    r"""Compute the TD(0) targets for a batch of episodic transitions.

    Let :math:`r_1, r_2, \dots, r_T` be a list of rewards and let
    :math:`V(s_0), V(s_1), \dots, V(s_{T-1}), V(s_{T})` be a list of state
    values including a last state value. With a discount factor
    :math:`\gamma`, the TD(0) targets are

    .. math::
        r_t + \gamma V(s_t), \forall t = 1, 2, \dots, T

    .. note::

        When ``reach_terminal`` is true, the value of the final state is
        replaced by zero and ``last_V`` does not enter the result.

    Args:
        gamma: discount factor multiplying the next-state values.
        rewards: rewards :math:`r_1, \dots, r_T`; converted with
            ``numpify(rewards, np.float32)``.
        Vs: state values :math:`V(s_0), \dots, V(s_{T-1})`; converted with
            ``numpify(Vs, np.float32)``.
        last_V: value of the final state :math:`V(s_T)`; converted with
            ``numpify(last_V, np.float32)`` even when it is not used.
        reach_terminal: if true, zero is used as the final state value
            instead of ``last_V``.

    Returns:
        The targets ``rewards + gamma*V_next`` cast to ``np.float32``.
    """
    rewards = numpify(rewards, np.float32)
    values = numpify(Vs, np.float32)
    bootstrap = numpify(last_V, np.float32)

    # A terminal final state carries no future value, so it is masked to zero.
    tail_value = 0.0 if reach_terminal else bootstrap
    # np.append (no axis) flattens both operands before concatenating.
    values = np.append(values, tail_value)

    next_values = values[1:]
    targets = rewards + gamma*next_values
    return targets.astype(np.float32)
def td0_error(gamma, rewards, Vs, last_V, reach_terminal):
    r"""Compute the TD(0) errors for a batch of episodic transitions.

    Let :math:`r_1, r_2, \dots, r_T` be a list of rewards and let
    :math:`V(s_0), V(s_1), \dots, V(s_{T-1}), V(s_{T})` be a list of state
    values including a last state value. With a discount factor
    :math:`\gamma`, the TD(0) errors are

    .. math::
        \delta_t = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)

    .. note::

        When ``reach_terminal`` is true, the value of the final state is
        replaced by zero and ``last_V`` does not enter the result.

    Args:
        gamma: discount factor multiplying the next-state values.
        rewards: rewards :math:`r_1, \dots, r_T`; converted with
            ``numpify(rewards, np.float32)``.
        Vs: state values :math:`V(s_0), \dots, V(s_{T-1})`; converted with
            ``numpify(Vs, np.float32)``.
        last_V: value of the final state :math:`V(s_T)`; converted with
            ``numpify(last_V, np.float32)`` even when it is not used.
        reach_terminal: if true, zero is used as the final state value
            instead of ``last_V``.

    Returns:
        The errors ``rewards + gamma*V_next - V_current`` cast to
        ``np.float32``.
    """
    rewards = numpify(rewards, np.float32)
    values = numpify(Vs, np.float32)
    bootstrap = numpify(last_V, np.float32)

    # A terminal final state carries no future value, so it is masked to zero.
    tail_value = 0.0 if reach_terminal else bootstrap
    # np.append (no axis) flattens both operands before concatenating.
    values = np.append(values, tail_value)

    current_values = values[:-1]
    next_values = values[1:]
    deltas = rewards + gamma*next_values - current_values
    return deltas.astype(np.float32)