Source code for lagom.metric.vtrace

import numpy as np

from lagom.utils import numpify

from .td import td0_error


[docs]def vtrace(behavior_logprobs, target_logprobs, gamma, Rs, Vs, last_V, reach_terminal, clip_rho=1.0, clip_pg_rho=1.0): behavior_logprobs = numpify(behavior_logprobs, np.float32) target_logprobs = numpify(target_logprobs, np.float32) Rs = numpify(Rs, np.float32) Vs = numpify(Vs, np.float32) last_V = numpify(last_V, np.float32) assert all([item.ndim == 1 for item in [behavior_logprobs, target_logprobs, Rs, Vs]]) assert np.isscalar(gamma) rhos = np.exp(target_logprobs - behavior_logprobs) clipped_rhos = np.minimum(clip_rho, rhos) cs = np.minimum(1.0, rhos) deltas = clipped_rhos*td0_error(gamma, Rs, Vs, last_V, reach_terminal) vs_minus_V = [] total = 0.0 for delta_t, c_t in zip(deltas[::-1], cs[::-1]): total = delta_t + gamma*c_t*total vs_minus_V.append(total) vs_minus_V = np.asarray(vs_minus_V)[::-1] vs = vs_minus_V + Vs vs_next = np.append(vs[1:], (1. - reach_terminal)*last_V) clipped_pg_rhos = np.minimum(clip_pg_rho, rhos) As = clipped_pg_rhos*(Rs + gamma*vs_next - Vs) return vs, As