Skip to content

Commit

Permalink
Fix E721 (pycodestyle)
Browse files Browse the repository at this point in the history
  • Loading branch information
keisen committed Oct 6, 2023
1 parent f2c65a5 commit de02308
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test__call__if_seed_input_is_(self, seed_input, expected, conv_model):
activation_maximization = ActivationMaximization(conv_model)
result = activation_maximization(CategoricalScore(0), seed_input=seed_input)
if type(expected) is list:
assert type(result) == list
assert type(result) is list
result = result[0]
expected = expected[0]
assert result.shape == expected
Expand Down Expand Up @@ -527,7 +527,7 @@ def test__call__if_seed_input_is_(self, seed_input, expected, multiple_outputs_m
result = activation_maximization(
[CategoricalScore(1), BinaryScore(False)], seed_input=seed_input)
if type(expected) is list:
assert type(result) == list
assert type(result) is list
result = result[0]
expected = expected[0]
assert result.shape == expected
Expand Down Expand Up @@ -876,7 +876,7 @@ def test__call__if_seed_input_is_(self, seed_input, expected, dense_model):
input_modifiers=None,
regularizers=None)
if type(expected) is list:
assert type(result) == list
assert type(result) is list
result = result[0]
expected = expected[0]
assert result.shape == expected
Expand Down
4 changes: 2 additions & 2 deletions tf_keras_vis/activation_maximization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def _get_input_ranges(self, input_range):
"The length of input range tuple must be 2 (Or it is just `None`, not tuple), "
f"but you passed {r} as `input_ranges[{i}]`.")
a, b = r
if None not in r and type(a) != type(b):
if None not in r and type(a) is not type(b):
raise TypeError(
"The type of low and high values in the input range must be the same, "
f"but you passed {r} are {type(a)} and {type(b)} ")
Expand Down Expand Up @@ -408,7 +408,7 @@ def _get_callables_to_apply_to_each_input(self, callables, object_name):
callables = ((k, listify(v)) for k, v in callables.items())
else:
callables = listify(callables)
if len(callables) == 0 or len(list(filter(lambda x: type(x) == list, callables))) == 0:
if len(callables) == 0 or len(list(filter(lambda x: type(x) is list, callables))) == 0:
callables = [callables]
if len(callables) <= len(keys):
callables = (listify(value_each_input) for value_each_input in callables)
Expand Down

0 comments on commit de02308

Please sign in to comment.