From de023082713d928dd46abe7cee6057a6a3eec01f Mon Sep 17 00:00:00 2001 From: keisen Date: Sat, 7 Oct 2023 00:24:03 +0900 Subject: [PATCH] Fix E721 (pycodestyle) --- .../activation_maximization/activation_maximization_test.py | 6 +++--- tf_keras_vis/activation_maximization/__init__.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/tf_keras_vis/activation_maximization/activation_maximization_test.py b/tests/tf_keras_vis/activation_maximization/activation_maximization_test.py index 7ba6c60..b13c909 100644 --- a/tests/tf_keras_vis/activation_maximization/activation_maximization_test.py +++ b/tests/tf_keras_vis/activation_maximization/activation_maximization_test.py @@ -51,7 +51,7 @@ def test__call__if_seed_input_is_(self, seed_input, expected, conv_model): activation_maximization = ActivationMaximization(conv_model) result = activation_maximization(CategoricalScore(0), seed_input=seed_input) if type(expected) is list: - assert type(result) == list + assert type(result) is list result = result[0] expected = expected[0] assert result.shape == expected @@ -527,7 +527,7 @@ def test__call__if_seed_input_is_(self, seed_input, expected, multiple_outputs_m result = activation_maximization( [CategoricalScore(1), BinaryScore(False)], seed_input=seed_input) if type(expected) is list: - assert type(result) == list + assert type(result) is list result = result[0] expected = expected[0] assert result.shape == expected @@ -876,7 +876,7 @@ def test__call__if_seed_input_is_(self, seed_input, expected, dense_model): input_modifiers=None, regularizers=None) if type(expected) is list: - assert type(result) == list + assert type(result) is list result = result[0] expected = expected[0] assert result.shape == expected diff --git a/tf_keras_vis/activation_maximization/__init__.py b/tf_keras_vis/activation_maximization/__init__.py index 3c5edfd..75d93be 100644 --- a/tf_keras_vis/activation_maximization/__init__.py +++ b/tf_keras_vis/activation_maximization/__init__.py @@ -299,7 +299,7 @@ def _get_input_ranges(self, input_range): "The length of input range tuple must be 2 (Or it is just `None`, not tuple), " f"but you passed {r} as `input_ranges[{i}]`.") a, b = r - if None not in r and type(a) != type(b): + if None not in r and type(a) is not type(b): raise TypeError( "The type of low and high values in the input range must be the same, " f"but you passed {r} are {type(a)} and {type(b)} ") @@ -408,7 +408,7 @@ def _get_callables_to_apply_to_each_input(self, callables, object_name): callables = ((k, listify(v)) for k, v in callables.items()) else: callables = listify(callables) - if len(callables) == 0 or len(list(filter(lambda x: type(x) == list, callables))) == 0: + if len(callables) == 0 or len(list(filter(lambda x: type(x) is list, callables))) == 0: callables = [callables] if len(callables) <= len(keys): callables = (listify(value_each_input) for value_each_input in callables)