"""Experimental ``wraps_ufunc`` decorator (``pint_xarray/wraps_ufunc.py``).

``wraps_ufunc`` wraps a plain function so that it accepts pint quantities and
pint-quantified ``xarray.DataArray`` inputs, converts them to the requested
units, evaluates the function through ``xarray.apply_ufunc`` and attaches the
requested units to the result(s).
"""
import functools
import warnings
from collections.abc import Mapping
from inspect import Parameter, signature
from typing import Union  # noqa: F401  (unused here; kept from the original import list)

import numpy as np
import xarray as xr
from pint import Quantity, UnitStrippedWarning

# NOTE(review): not referenced by name in this module; presumably imported so
# that the ``.pint`` accessor is registered -- confirm before removing.
import pint_xarray  # noqa: F401

from .accessors import default_registry

# Plain (unit-less) types that are accepted wherever a dimensionless value is allowed.
_handled_types = (float, int, np.ndarray)


def _check_wrapper_args(unit_iterable):
    """Validate the unit specifications handed to the decorator.

    Parameters
    ----------
    unit_iterable : mapping or iterable
        Either a mapping ``name -> unit`` (the form ``wraps_ufunc`` uses), in
        which case the *values* are validated, or a plain iterable of units.

    Raises
    ------
    TypeError
        If a unit is neither ``None``, a ``str`` nor a ``default_registry.Unit``.
    """
    # Iterating a dict yields its keys (the parameter names), which are always
    # strings -- the units themselves live in the values.
    if isinstance(unit_iterable, Mapping):
        units = unit_iterable.values()
    else:
        units = unit_iterable

    for arg in units:
        if arg is not None and not isinstance(arg, (default_registry.Unit, str)):
            raise TypeError(
                f"wraps arguments must be of type str or Unit, not {type(arg)} ({arg})"
            )


def _convert_units(key, value, unit, debug=False):
    """Convert ``value`` such that it has units of ``unit``.

    Parameters
    ----------
    key : str
        Name of the argument, used only in messages.
    value
        The argument value: a ``DataArray``, a ``Quantity`` or (for
        dimensionless units only) one of ``_handled_types``.
    unit : str, Unit or None
        Target unit. ``None`` returns ``value`` untouched; ``""`` or
        ``default_registry.dimensionless`` additionally allows plain floats,
        ints and numpy arrays.
    debug : bool, default: False
        Print what is being converted.

    Returns
    -------
    The converted value. ``Quantity`` inputs are returned as their bare
    magnitude; ``DataArray`` inputs are returned as the result of
    ``value.pint.to(unit)``.

    Raises
    ------
    TypeError
        If a unit is required but ``value`` is neither a ``DataArray`` nor a
        ``Quantity``.
    NotImplementedError
        If ``value`` has a type that is not handled.
    """
    if debug:
        print(f"Converting {key} with value {value} to unit {unit}")

    if unit is None:
        # Do nothing
        return value

    dimensionless = (unit == "") or (unit == default_registry.dimensionless)

    if not dimensionless and not isinstance(value, (xr.DataArray, Quantity)):
        raise TypeError(
            f"Input for {key} is of type {type(value)} which does not contain units, but units of {unit} are required for {key}."
        )

    if isinstance(value, xr.DataArray):
        if value.pint.units is None:
            # Unquantified arrays are treated as dimensionless before converting.
            value = value.pint.quantify("")
        return value.pint.to(unit)
    elif isinstance(value, Quantity):
        return value.to(unit).magnitude
    elif dimensionless and isinstance(value, _handled_types):
        return value

    # Catch unhandled case
    raise NotImplementedError(
        f"Could not process input for {key} with type={type(value)}."
    )


def _set_units(value, unit, key):
    """Set the units of ``value`` to ``unit`` (or leave as is if ``unit`` is None).

    Parameters
    ----------
    value
        A returned value: ``DataArray``, ``Quantity`` or one of ``_handled_types``.
    unit : str, Unit or None
        Unit to attach or convert to.
    key : str
        Name of the output, used only in the error message.

    Raises
    ------
    ValueError
        If ``default_registry`` has neither ``force_ndarray_like`` nor
        ``force_ndarray`` set. This is checked before anything else, including
        the ``unit is None`` shortcut.
    NotImplementedError
        If ``value`` has a type that is not handled.
    """
    if not default_registry.force_ndarray_like and not default_registry.force_ndarray:
        raise ValueError(
            "Set 'force_ndarray_like' or 'force_ndarray' when defining default_registry=UnitRegistry(...)."
        )

    if unit is None:
        return value

    dimensionless = (unit == "") or (unit == default_registry.dimensionless)

    # Plain dimensionless results are passed through without wrapping.
    if dimensionless and isinstance(value, _handled_types):
        return value

    if isinstance(value, xr.DataArray):
        if value.pint.units is None:
            return value.pint.quantify(unit, unit_registry=default_registry)
        return value.pint.to(unit)

    if isinstance(value, Quantity):
        return value.to(unit)

    if isinstance(value, _handled_types):
        return Quantity(value, unit)

    # Catch unhandled case
    raise NotImplementedError(
        f"Could not process output with type={type(value)} for output {key}."
    )


def _get_default_values(func):
    """Return ``{parameter name: default}`` for every parameter of ``func`` that has a default."""
    func_signature = signature(func)

    return {
        k: v.default
        for k, v in func_signature.parameters.items()
        if v.default is not Parameter.empty
    }


def wraps_ufunc(
    _func=None,
    *,
    return_units: dict[str, str],
    input_units: dict[str, str],
    ufunc_kwargs_from_wrapper=None,
    auto_none: bool = False,
):
    """Wrap a function to accept pint-xarray inputs and evaluate it via ``xarray.apply_ufunc``.

    You can pass arguments to apply_ufunc
    (see https://docs.xarray.dev/en/stable/generated/xarray.apply_ufunc.html)
    by supplying an additional ``ufunc_kwargs`` keyword argument when calling
    the wrapped function, or by setting ``ufunc_kwargs_from_wrapper`` when
    calling ``@wraps_ufunc``. Per-call values override the wrapper-level ones.

    For each element in ``return_units`` and ``input_units``, you can set

    * ``None``: do nothing with this argument
    * ``""`` or ``default_registry.dimensionless``: allow untyped floats, ints
      and np.arrays. Ensure Quantity and xr.DataArray units reduce to
      dimensionless.
    * ``"<unit>"`` or ``default_registry.<unit>``: require Quantity or
      xr.DataArray, convert their units to ``<unit>``.

    Closely based on pint.registry_helpers.wraps.

    Parameters
    ----------
    _func : callable, optional
        The function to wrap, when the decorator is applied directly.
    return_units : dict
        Maps an output name to its unit, in the order the function returns them.
    input_units : dict
        Maps a parameter name of the function to its unit.
    ufunc_kwargs_from_wrapper : dict, optional
        Keyword arguments forwarded to ``xarray.apply_ufunc``. Unless given,
        ``vectorize=True`` and one empty ``output_core_dims`` entry per
        returned value are used. The dict is copied, never modified.
    auto_none : bool, default: False
        If True, parameters missing from ``input_units`` get the unit ``None``.

    Raises
    ------
    TypeError
        If a unit has an invalid type, ``input_units`` names a parameter the
        function does not have, or the number of units does not match the
        number of parameters. The wrapped function raises ``TypeError`` if the
        number of returned values does not match ``return_units`` or too many
        positional arguments are passed.

    Notes
    -----
    The wrapped function reserves the keyword names ``ufunc_kwargs`` and
    ``debug`` for itself.
    """
    _check_wrapper_args(input_units)
    _check_wrapper_args(return_units)

    # Work on a private copy: mutating the argument would leak settings (e.g.
    # "output_core_dims") into the caller's dict or into later decorations.
    wrapper_ufunc_kwargs = dict(ufunc_kwargs_from_wrapper or {})
    # Unless explicitly set to False, assume that the user wants vectorize=True.
    wrapper_ufunc_kwargs.setdefault("vectorize", True)
    # Unless explicitly set, assume that the user doesn't want to add new
    # dimensions to the output.
    wrapper_ufunc_kwargs.setdefault("output_core_dims", [() for _ in return_units])

    def decorator(func):
        # Work out what the function is expecting as inputs
        func_parameters = signature(func).parameters
        count_params = len(func_parameters)
        default_values = _get_default_values(func)

        # Per-function copy, so "auto_none" never modifies the caller's dict and
        # the same decorator can be applied to several functions.
        units_by_param = dict(input_units)
        if auto_none:
            # Only need to include units which you want to perform unit-checking on.
            for parameter in func_parameters:
                units_by_param.setdefault(parameter, None)

        # Checked before the count so that a misspelled name is reported as such.
        for parameter in units_by_param:
            if parameter not in func_parameters:
                raise TypeError(
                    f"{func.__name__} does not have a parameter {parameter}"
                )

        # Make sure that there are as many units as there are arguments
        if len(units_by_param) != count_params:
            raise TypeError(
                f"{func.__name__} takes {count_params} parameters, but {len(units_by_param)} units were passed"
            )

        @functools.wraps(func)
        def wrapper(*positional_args, ufunc_kwargs=None, debug=False, **keyword_args):
            # zip() below would silently drop surplus positional arguments.
            if len(positional_args) > count_params:
                raise TypeError(
                    f"{func.__name__} takes {count_params} positional arguments, but {len(positional_args)} were given"
                )

            # Convert all positional arguments into keyword arguments
            # (keyword_args is a fresh dict for every call, so this is local).
            for arg, param in zip(positional_args, func_parameters):
                if param in keyword_args:
                    raise KeyError(f"Repeated argument for {param}")
                keyword_args[param] = arg

            # Add default values if they aren't already set
            keyword_args = {**default_values, **keyword_args}

            if debug:
                print(f"{func} called with {keyword_args}")

            # Convert the input into the desired units
            for key in func_parameters:
                keyword_args[key] = _convert_units(
                    key, keyword_args[key], units_by_param[key], debug=debug
                )

            # Per-call settings take precedence over the wrapper-level ones.
            merged_ufunc_kwargs = {**wrapper_ufunc_kwargs, **(ufunc_kwargs or {})}

            with warnings.catch_warnings():
                # Suppress the UnitStrippedWarning -- we want to drop units and
                # down-cast to unitless arrays since this is what the
                # pint.registry_helpers.wraps decorator does.
                warnings.simplefilter("ignore", category=UnitStrippedWarning)

                if debug:
                    print(
                        f"apply_ufunc called with {keyword_args}, with kwargs {merged_ufunc_kwargs}"
                    )

                function_return = xr.apply_ufunc(
                    func,
                    *[keyword_args[param] for param in func_parameters],
                    **merged_ufunc_kwargs,
                )

            if debug:
                print(f"apply_ufunc returned {function_return}")

            # Convert the output into the desired units
            if isinstance(function_return, tuple):
                # Multiple return
                if len(return_units) != len(function_return):
                    raise TypeError(
                        f"{func.__name__} returned {len(function_return)} value(s), but {len(return_units)} units were passed"
                    )

                return tuple(
                    _set_units(returned, unit, key)
                    for returned, (key, unit) in zip(
                        function_return, return_units.items()
                    )
                )

            # Single return
            if (
                len(return_units) == 0
                and isinstance(function_return, xr.DataArray)
                # Element-wise comparison against None is intended here.
                and np.all(function_return == None)  # noqa: E711
            ):
                return None

            if len(return_units) != 1:
                # A non-tuple result is a single value; len() on it would fail
                # or report an array dimension instead.
                raise TypeError(
                    f"{func.__name__} returned 1 value, but {len(return_units)} units were passed"
                )

            ((key, unit),) = return_units.items()
            return _set_units(function_return, unit=unit, key=key)

        return wrapper

    if _func is None:
        return decorator
    return decorator(_func)