
kfac-jax 0.0.3

@botev released this on 23 Sep 20:12

What's Changed

  • Changing the version in the citation text in the README. by @copybara-service in #29
  • Adding attributes for the number of training and evaluation devices. by @copybara-service in #31
  • Adding some methods to ImplicitExactCurvature. by @copybara-service in #32
  • Adding a "put_stop_grad_on_loss_factor" argument to "multiply_fisher_factor". by @copybara-service in #36
  • Making ScaleAndShift blocks capable of having parameters that are broadcast by construction, e.g. batch norm with scale parameters of shape [1, 1, 1, d]. by @copybara-service in #33
  • Changing jax.tree_map -> jax.tree_util.tree_map (and related functions) due to a recent deprecation. by @copybara-service in #37
  • Removing the unused precedence argument from GraphPattern. by @copybara-service in #38
  • Fixing a small bug where the jaxpr constvars were not checked. by @copybara-service in #39
  • Adding an estimator attribute to the optimizer. by @copybara-service in #34
  • Updating the docs to correctly refer to update_cache. by @copybara-service in #40
  • Comparing with slightly lower numerical precision. by @copybara-service in #41
  • Revamping the graph matching code so that it can detect layers and register tags inside arbitrary higher-order JAX primitives. by @copybara-service in #42
  • Revising the docstring for the optimizer class; it now includes the previously missing details about value_and_grad_func. by @copybara-service in #43
  • Internal change. by @copybara-service in #44
  • Making LossTag return only the parameter-dependent arrays. by @copybara-service in #46
  • Improving LossTags so they deal correctly with None arguments, by passing in argument names. by @copybara-service in #47
  • Minor fix to a bug introduced in the previous commit. by @copybara-service in #48
  • Correcting issues with the docstring for the optimizer. by @copybara-service in #45
  • Fixing a bug in the graph matcher introduced in a recent CL. by @copybara-service in #49
  • Removing the unneeded jax.jit in get_mean and get_sum. by @copybara-service in #50
  • Adding per-parameter norm stats to the optimizer. by @copybara-service in #51
  • Allowing the pi-adjusted PSD inverse to accept diagonal factors. by @copybara-service in #55
  • Fixing the wrong type annotation of pmap_axis_name. by @copybara-service in #56
  • Adding optional offloading of the eigh computation to the host, because of a bug in the CUDA 11.7.0 cuSOLVER library. by @copybara-service in #57
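The tree_map change in #37 is only a spelling change for callers. Here is a minimal sketch of the preferred spelling. The parameter dictionary `params` is a hypothetical example and does not come from the release.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree, used only for illustration.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# Deprecated alias: jax.tree_map(fn, tree)
# Preferred spelling: jax.tree_util.tree_map(fn, tree)
scaled = jax.tree_util.tree_map(lambda x: 2.0 * x, params)
```

The result has the same tree structure as `params`, with the function applied to every leaf.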
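The plain-JAX sketch below illustrates what "broadcast by construction" means for a scale-and-shift layer (#33). It is not kfac-jax's ScaleAndShift block. The helper `scale_and_shift` and the NHWC shapes are illustrative assumptions. Parameters of shape [1, 1, 1, d] broadcast against an input of shape [batch, height, width, d]. Their gradients are summed over the broadcast axes, so the gradients keep the parameter shape.

```python
import jax
import jax.numpy as jnp

def scale_and_shift(x, scale, shift):
    # Hypothetical elementwise affine map; scale/shift broadcast against x.
    return x * scale + shift

batch, height, width, d = 4, 8, 8, 16
x = jnp.ones((batch, height, width, d))
scale = jnp.ones((1, 1, 1, d))   # batch-norm-style parameter shape
shift = jnp.zeros((1, 1, 1, d))

def loss(scale, shift):
    return jnp.sum(scale_and_shift(x, scale, shift))

g_scale, g_shift = jax.grad(loss, argnums=(0, 1))(scale, shift)
# Gradients have the parameter shape [1, 1, 1, d]; broadcast axes are summed out.
print(g_scale.shape)  # (1, 1, 1, 16)
```

Each gradient entry here equals batch * height * width = 256, because the summed-out axes each contribute one term per position.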
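The release notes do not say how #57 offloads eigh to the host. One possible way to run an eigendecomposition on the host from jitted JAX code is `jax.pure_callback` wrapping NumPy's `eigh`. The sketch below assumes that approach. The helper names `_host_eigh` and `eigh_on_host` are hypothetical, so this should not be read as kfac-jax's implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np

def _host_eigh(a):
    # Runs on the host with NumPy arrays; cast back to the input dtype.
    w, v = np.linalg.eigh(np.asarray(a))
    return w.astype(a.dtype), v.astype(a.dtype)

@jax.jit
def eigh_on_host(a):
    n = a.shape[-1]
    # Declare the shapes/dtypes the host callback will return.
    out_types = (
        jax.ShapeDtypeStruct((n,), a.dtype),
        jax.ShapeDtypeStruct((n, n), a.dtype),
    )
    return jax.pure_callback(_host_eigh, out_types, a)

m = jnp.array([[2.0, 1.0], [1.0, 3.0]], dtype=jnp.float32)
w, v = eigh_on_host(m)  # eigenvalues in ascending order, eigenvectors as columns
```

With this design, the compiled program never calls the device eigensolver. The trade-off is a device-to-host transfer on every call.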

Full Changelog: v0.0.2...v0.0.3