"""Module for random variables.

This module contains classes for random variables and exceptions.

Classes:
    ParameterException: Exception for invalid parameters.
    RandomVariable: Abstract class for random variables.
    Discrete: Class for discrete random variables.
    Gaussian: Class for Gaussian random variables.
"""
  9. from abc import ABC, abstractmethod, abstractproperty, abstractclassmethod
  10. import numpy as np
  11. class ParameterException(Exception):
  12. """Exception for invalid parameters."""
  13. pass
  14. class RandomVariable(ABC):
  15. """Abstract base class for all random variables."""
  16. @abstractclassmethod
  17. def unity(cls, *args):
  18. pass
  19. @abstractproperty
  20. def dim(self):
  21. pass
  22. @abstractmethod
  23. def __str__(self):
  24. pass
  25. @abstractmethod
  26. def __add__(self):
  27. pass
  28. @abstractmethod
  29. def __sub__(self):
  30. pass
  31. @abstractmethod
  32. def __mul__(self):
  33. pass
  34. @abstractmethod
  35. def __iadd__(self):
  36. pass
  37. @abstractmethod
  38. def __isub__(self):
  39. pass
  40. @abstractmethod
  41. def __imul__(self):
  42. pass
  43. @abstractmethod
  44. def __eq__(self):
  45. pass
  46. @abstractmethod
  47. def normalize(self):
  48. pass
  49. @abstractmethod
  50. def marginalize(self):
  51. pass
  52. @abstractmethod
  53. def maximize(self):
  54. pass
  55. @abstractmethod
  56. def argmax(self):
  57. pass
  58. @abstractmethod
  59. def log(self):
  60. pass
  61. class Discrete(RandomVariable):
  62. """Class for discrete random variables.
  63. A discrete random variable is defined by a single- or multi-dimensional
  64. probability mass function. In addition, each dimension of the probability
  65. mass function has to be associated with a variable. The variable is
  66. represented by a variable node of the comprehensive factor graph.
  67. """
  68. def __init__(self, raw_pmf, *args):
  69. """Initialize a discrete random variable.
  70. Create a new discrete random variable with the given probability
  71. mass function over the given variable nodes.
  72. Args:
  73. raw_pmf: A Numpy array representing the probability mass function.
  74. The probability mass function does not need to be normalized.
  75. *args: Instances of the class VNode representing the variables of
  76. the probability mass function. The number of the positional
  77. arguments must match the number of dimensions of the Numpy
  78. array.
  79. Raises:
  80. ParameterException: An error occurred initializing with invalid
  81. parameters.
  82. """
  83. pmf = np.asarray(raw_pmf, dtype=np.float64)
  84. # Set probability mass function
  85. self._pmf = pmf
  86. # Set variable nodes for dimensions
  87. if np.ndim(pmf) != len(args):
  88. raise ParameterException('Dimension mismatch.')
  89. else:
  90. self._dim = args
  91. @classmethod
  92. def unity(cls, *args):
  93. """Initialize unit element of a discrete random variable.
  94. Args:
  95. *args: Instances of the class VNode representing the variables of
  96. the probability mass function. The number of the positional
  97. arguments must match the number of dimensions of the Numpy
  98. array.
  99. Raises:
  100. ParameterException: An error occurred initializing with invalid
  101. parameters.
  102. """
  103. n = len(args)
  104. return cls(np.ones((1,) * n), *args)
  105. @property
  106. def pmf(self):
  107. return self._pmf
  108. @property
  109. def dim(self):
  110. return self._dim
  111. def __str__(self):
  112. """Return string representation of the discrete random variable."""
  113. return str(self.pmf)
  114. def __add__(self, other):
  115. """Add other to self and return the result.
  116. Args:
  117. other: Summand for the discrete random variable.
  118. Returns:
  119. A new discrete random variable representing the summation.
  120. """
  121. # Verify dimensions of summand and summand.
  122. if len(self.dim) < len(other.dim):
  123. self._expand(other.dim, other.pmf.shape)
  124. elif len(self.dim) > len(other.dim):
  125. other._expand(self.dim, self.pmf.shape)
  126. pmf = self.pmf + other.pmf
  127. return Discrete(pmf, *self.dim)
  128. def __sub__(self, other):
  129. """Subtract other from self and return the result.
  130. Args:
  131. other: Subtrahend for the discrete random variable.
  132. Returns:
  133. A new discrete random variable representing the subtraction.
  134. """
  135. # Verify dimensions of minuend and subtrahend.
  136. if len(self.dim) < len(other.dim):
  137. self._expand(other.dim, other.pmf.shape)
  138. elif len(self.dim) > len(other.dim):
  139. other._expand(self.dim, self.pmf.shape)
  140. pmf = self.pmf - other.pmf
  141. return Discrete(pmf, *self.dim)
  142. def __mul__(self, other):
  143. """Multiply other with self and return the result.
  144. Args:
  145. other: Multiplier for the discrete random variable.
  146. Returns:
  147. A new discrete random variable representing the multiplication.
  148. """
  149. # Verify dimensions of multiplicand and multiplier.
  150. if len(self.dim) < len(other.dim):
  151. self._expand(other.dim, other.pmf.shape)
  152. elif len(self.dim) > len(other.dim):
  153. other._expand(self.dim, self.pmf.shape)
  154. pmf = self.pmf * other.pmf
  155. return Discrete(pmf, *self.dim)
  156. def __iadd__(self, other):
  157. """Method for augmented addition.
  158. Args:
  159. other: Summand for the discrete random variable.
  160. Returns:
  161. A new discrete random variable representing the summation.
  162. """
  163. return self.__add__(other)
  164. def __isub__(self, other):
  165. """Method for augmented subtraction.
  166. Args:
  167. other: Subtrahend for the discrete random variable.
  168. Returns:
  169. A new discrete random variable representing the subtraction.
  170. """
  171. return self.__sub__(other)
  172. def __imul__(self, other):
  173. """Method for augmented multiplication.
  174. Args:
  175. other: Multiplier for the discrete random variable.
  176. Returns:
  177. A new discrete random variable representing the multiplication.
  178. """
  179. return self.__mul__(other)
  180. def __eq__(self, other):
  181. """Compare self with other and return the boolean result.
  182. Two discrete random variables are equal only if the probability mass
  183. functions are equal and the order of dimensions are equal.
  184. """
  185. return np.allclose(self.pmf, other.pmf) \
  186. and self.dim == other.dim
  187. def _expand(self, dims, states):
  188. """Expand dimensions.
  189. Expand the discrete random variable along the given new dimensions.
  190. Args:
  191. dims: List of discrete random variables.
  192. """
  193. reps = [1, ] * len(dims)
  194. # Extract missing dimensions
  195. diff = [i for i, d in enumerate(dims) if d not in self.dim]
  196. # Expand missing dimensions
  197. for d in diff:
  198. self._pmf = np.expand_dims(self.pmf, axis=d)
  199. reps[d] = states[d]
  200. # Repeat missing dimensions
  201. self._pmf = np.tile(self.pmf, reps)
  202. self._dim = dims
  203. def normalize(self):
  204. """Normalize probability mass function."""
  205. pmf = self.pmf / np.sum(self.pmf)
  206. return Discrete(pmf, *self.dim)
  207. def marginalize(self, *dims, normalize=True):
  208. """Return the marginal for given dimensions.
  209. The probability mass function of the discrete random variable
  210. is marginalized along the given dimensions.
  211. Args:
  212. *dims: Instances of discrete random variables, which should be
  213. marginalized out.
  214. normalize: Boolean flag if probability mass function should be
  215. normalized after marginalization.
  216. Returns:
  217. A new discrete random variable representing the marginal.
  218. """
  219. axis = tuple(idx for idx, d in enumerate(self.dim) if d in dims)
  220. pmf = np.sum(self.pmf, axis)
  221. if normalize:
  222. pmf /= np.sum(pmf)
  223. new_dims = tuple(d for d in self.dim if d not in dims)
  224. return Discrete(pmf, *new_dims)
  225. def maximize(self, *dims, normalize=True):
  226. """Return the maximum for given dimensions.
  227. The probability mass function of the discrete random variable
  228. is maximized along the given dimensions.
  229. Args:
  230. *dims: Instances of discrete random variables, which should be
  231. maximized out.
  232. normalize: Boolean flag if probability mass function should be
  233. normalized after marginalization.
  234. Returns:
  235. A new discrete random variable representing the maximum.
  236. """
  237. axis = tuple(idx for idx, d in enumerate(self.dim) if d in dims)
  238. pmf = np.amax(self.pmf, axis)
  239. if normalize:
  240. pmf /= np.sum(pmf)
  241. new_dims = tuple(d for d in self.dim if d not in dims)
  242. return Discrete(pmf, *new_dims)
  243. def argmax(self, dim=None):
  244. """Return the dimension index of the maximum.
  245. Args:
  246. dim: An optional discrete random variable along a marginalization
  247. should be performed and the maximum is searched over the
  248. remaining dimensions. In the case of None, the maximum is
  249. search along all dimensions.
  250. Returns:
  251. An integer representing the dimension of the maximum.
  252. """
  253. if dim is None:
  254. return np.unravel_index(self.pmf.argmax(), self.pmf.shape)
  255. m = self.marginalize(dim)
  256. return np.argmax(m.pmf)
  257. def log(self):
  258. """Natural logarithm of the discrete random variable.
  259. Returns:
  260. A new discrete random variable with the natural logarithm of the
  261. probablitiy mass function.
  262. """
  263. return Discrete(np.log(self.pmf), *self.dim)
  264. class Gaussian(RandomVariable):
  265. """Class for Gaussian random variables.
  266. A Gaussian random variable is defined by a mean vector and a covariance
  267. matrix. In addition, each dimension of the mean vector and the covariance
  268. matrix has to be associated with a variable. The variable is
  269. represented by a variable node of the comprehensive factor graph.
  270. """
  271. def __init__(self, raw_mean, raw_cov, *args):
  272. """Initialize a Gaussian random variable.
  273. Create a new Gaussian random variable with the given mean vector and
  274. the given covariance matrix over the given variable nodes.
  275. Args:
  276. raw_mean: A Numpy array representing the mean vector.
  277. raw_cov: A Numpy array representing the covariance matrix.
  278. *args: Instances of the class VNode representing the variables of
  279. the mean vector and covariance matrix, respectively. The number
  280. of the positional arguments must match the number of dimensions
  281. of the Numpy arrays.
  282. Raises:
  283. ParameterException: An error occurred initializing with invalid
  284. parameters.
  285. """
  286. if raw_mean is not None and raw_cov is not None:
  287. mean = np.asarray(raw_mean, dtype=np.float64)
  288. cov = np.asarray(raw_cov, dtype=np.float64)
  289. # Set mean vector and covariance matrix
  290. if mean.shape[0] != cov.shape[0]:
  291. raise ParameterException('Dimension mismatch.')
  292. else:
  293. # Precision matrix
  294. self._W = np.linalg.inv(np.asarray(cov))
  295. # Precision-mean vector
  296. self._Wm = np.dot(self._W, np.asarray(mean))
  297. # Set variable nodes for dimensions
  298. if cov.shape[0] != len(args):
  299. raise ParameterException('Dimension mismatch.')
  300. else:
  301. self._dim = args
  302. else:
  303. self._dim = args
  304. @classmethod
  305. def unity(cls, *args):
  306. """Initialize unit element of a Gaussian random variable.
  307. Args:
  308. *args: Instances of the class VNode representing the variables of
  309. the mean vector and covariance matrix, respectively. The number
  310. of the positional arguments must match the number of dimensions
  311. of the Numpy arrays.
  312. Raises:
  313. ParameterException: An error occurred initializing with invalid
  314. parameters.
  315. """
  316. n = len(args)
  317. return cls(np.diag(np.zeros(n)), np.diag(np.ones(n) * np.Inf), *args)
  318. @classmethod
  319. def inf_form(cls, raw_W, raw_Wm, *args):
  320. """Initialize a Gaussian random variable using the information form.
  321. Create a new Gaussian random variable with the given mean vector and
  322. the given covariance matrix over the given variable nodes.
  323. Args:
  324. raw_W: A Numpy array representing the precision matrix.
  325. raw_Wm: A Numpy array representing the precision-mean vector.
  326. *args: Instances of the class VNode representing the variables of
  327. the mean vector and covariance matrix, respectively. The number
  328. of the positional arguments must match the number of dimensions
  329. of the Numpy arrays.
  330. Raises:
  331. ParameterException: An error occurred initializing with invalid
  332. parameters.
  333. """
  334. g = cls(None, None, *args)
  335. g._W = np.asarray(raw_W, dtype=np.float64)
  336. g._Wm = np.asarray(raw_Wm, dtype=np.float64)
  337. return g
  338. @property
  339. def mean(self):
  340. return np.dot(np.linalg.inv(self._W), self._Wm)
  341. @property
  342. def cov(self):
  343. return np.linalg.inv(self._W)
  344. @property
  345. def dim(self):
  346. return self._dim
  347. def __str__(self):
  348. """Return string representation of the Gaussian random variable."""
  349. return "%s %s" % (self.mean, self.cov)
  350. def __add__(self, other):
  351. """Add other to self and return the result.
  352. Args:
  353. other: Summand for the Gaussian random variable.
  354. Returns:
  355. A new Gaussian random variable representing the summation.
  356. """
  357. return Gaussian(self.mean + other.mean,
  358. self.cov + other.cov,
  359. *self.dim)
  360. def __sub__(self, other):
  361. """Subtract other from self and return the result.
  362. Args:
  363. other: Subrahend for the Gaussian random variable.
  364. Returns:
  365. A new Gaussian random variable representing the subtraction.
  366. """
  367. return Gaussian(self.mean - other.mean,
  368. self.cov - other.cov,
  369. *self.dim)
  370. def __mul__(self, other):
  371. """Multiply other with self and return the result.
  372. Args:
  373. other: Multiplier for the Gaussian random variable.
  374. Returns:
  375. A new Gaussian random variable representing the multiplication.
  376. """
  377. W = self._W + other._W
  378. Wm = self._Wm + other._Wm
  379. return Gaussian.inf_form(W, Wm, *self.dim)
  380. def __iadd__(self, other):
  381. """Method for augmented addition.
  382. Args:
  383. other: Summand for the Gaussian random variable.
  384. Returns:
  385. A new Gaussian random variable representing the summation.
  386. """
  387. return self.__add__(other)
  388. def __isub__(self, other):
  389. """Method for augmented subtraction.
  390. Args:
  391. other: Subtrahend for the Gaussian random variable.
  392. Returns:
  393. A new Gaussian random variable representing the subtraction.
  394. """
  395. return self.__sub__(other)
  396. def __imul__(self, other):
  397. """Method for augmented multiplication.
  398. Args:
  399. other: Multiplier for the Gaussian random variable.
  400. Returns:
  401. A new Gaussian random variable representing the multiplication.
  402. """
  403. return self.__mul__(other)
  404. def __eq__(self, other):
  405. """Compare self with other and return the boolean result.
  406. Two Gaussian random variables are equal only if the mean vectors and
  407. the covariance matrices are equal and the order of dimensions are
  408. equal.
  409. """
  410. return np.allclose(self._W, other._W) \
  411. and np.allclose(self._Wm, other._Wm) \
  412. and self.dim == other.dim
  413. def normalize(self):
  414. """Normalize probability density function."""
  415. return self
  416. def marginalize(self, *dims):
  417. """Return the marginal for given dimensions.
  418. The probability density function of the Gaussian random variable
  419. is marginalized along the given dimensions.
  420. Args:
  421. *dims: Instances of Gaussian random variables, which should be
  422. marginalized out.
  423. Returns:
  424. A new Gaussian random variable representing the marginal.
  425. """
  426. axis = tuple(idx for idx, d in enumerate(self.dim) if d not in dims)
  427. mean = self.mean[np.ix_(axis, [0])]
  428. cov = self.cov[np.ix_(axis, axis)]
  429. new_dims = tuple(d for d in self.dim if d not in dims)
  430. return Gaussian(mean, cov, *new_dims)
  431. def maximize(self, *dims):
  432. """Return the maximum for given dimensions.
  433. The probability density function of the Gaussian random variable
  434. is maximized along the given dimensions.
  435. Args:
  436. *dims: Instances of Gaussian random variables, which should be
  437. maximized out.
  438. Returns:
  439. A new Gaussian random variable representing the maximum.
  440. """
  441. axis = tuple(idx for idx, d in enumerate(self.dim) if d not in dims)
  442. mean = self.mean[np.ix_(axis, [0])]
  443. cov = self.cov[np.ix_(axis, axis)]
  444. new_dims = tuple(d for d in self.dim if d not in dims)
  445. return Gaussian(mean, cov, *new_dims)
  446. def argmax(self, dim=None):
  447. """Return the dimension index of the maximum.
  448. Args:
  449. dim: An optional Gaussian random variable along a marginalization
  450. should be performed and the maximum is searched over the
  451. remaining dimensions. In the case of None, the maximum is
  452. search along all dimensions.
  453. Returns:
  454. An integer representing the dimension of the maximum.
  455. """
  456. if dim is None:
  457. return self.mean
  458. m = self.marginalize(dim)
  459. return m.mean
  460. def log(self):
  461. """Natural logarithm of the Gaussian random variable.
  462. Returns:
  463. A new Gaussian random variable with the natural logarithm of the
  464. probability density function.
  465. """
  466. raise NotImplementedError