Skip to content

Commit

Permalink
use .sizes instead of .dims to avoid warning and future deprecation
Browse files Browse the repository at this point in the history
  • Loading branch information
timothyas committed Jun 21, 2024
1 parent 8debd72 commit 1ea79ba
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions graphcast/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _validate_targets_and_forcings(self, targets, forcings):
f'forcings, which isn\'t allowed: {overlap}')

def _update_inputs(self, inputs, next_frame):
num_inputs = inputs.dims['time']
num_inputs = inputs.sizes['time']

predicted_or_forced_inputs = next_frame[list(inputs.keys())]

Expand Down Expand Up @@ -199,7 +199,7 @@ def one_step_prediction(inputs, scan_variables):
return next_inputs, flat_pred

if self._gradient_checkpointing:
scan_length = targets_template.dims['time']
scan_length = targets_template.sizes['time']
if scan_length <= 1:
logging.warning(
'Skipping gradient checkpointing for sequence length of 1')
Expand Down
4 changes: 2 additions & 2 deletions graphcast/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def chunked_prediction_generator(
if "datetime" in forcings.coords:
del forcings.coords["datetime"]

num_target_steps = targets_template.dims["time"]
num_target_steps = targets_template.sizes["time"]
num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk)
if remainder != 0:
raise ValueError(
Expand Down Expand Up @@ -202,7 +202,7 @@ def _get_next_inputs(
next_inputs = next_frame[next_inputs_keys]

# Apply concatenate next frame with inputs, crop what we don't need.
num_inputs = prev_inputs.dims["time"]
num_inputs = prev_inputs.sizes["time"]
return (
xarray.concat(
[prev_inputs, next_inputs], dim="time", data_vars="different")
Expand Down

0 comments on commit 1ea79ba

Please sign in to comment.