diff --git a/examples/smart_contracts/assert_sum.py b/examples/smart_contracts/assert_sum.py index a36d6cb0..4adf34e7 100644 --- a/examples/smart_contracts/assert_sum.py +++ b/examples/smart_contracts/assert_sum.py @@ -3,6 +3,9 @@ def validator(datum: int, redeemer: int, context: ScriptContext) -> None: + purpose = context.purpose + if not isinstance(purpose, Spending): + print(f"Wrong script purpose: {purpose}") assert ( datum + redeemer == 42 ), f"Expected datum and redeemer to sum to 42, but they sum to {datum + redeemer}" diff --git a/opshin/optimize/optimize_const_folding.py b/opshin/optimize/optimize_const_folding.py index 6191c964..db91cf39 100644 --- a/opshin/optimize/optimize_const_folding.py +++ b/opshin/optimize/optimize_const_folding.py @@ -3,6 +3,7 @@ import logging from ast import * +from ordered_set import OrderedSet from pycardano import PlutusData @@ -98,7 +99,7 @@ class ShallowNameDefCollector(CompilingNodeVisitor): step = "Collecting occuring variable names" def __init__(self): - self.vars = set() + self.vars = OrderedSet() def visit_Name(self, node: Name) -> None: if isinstance(node.ctx, Store): @@ -172,13 +173,13 @@ class OptimizeConstantFolding(CompilingNodeTransformer): def __init__(self): self.scopes_visible = [ - set(INITIAL_SCOPE.keys()).difference(SAFE_GLOBALS.keys()) + OrderedSet(INITIAL_SCOPE.keys()).difference(SAFE_GLOBALS.keys()) ] self.scopes_constants = [dict()] - self.constants = set() + self.constants = OrderedSet() def enter_scope(self): - self.scopes_visible.append(set()) + self.scopes_visible.append(OrderedSet()) self.scopes_constants.append(dict()) def add_var_visible(self, var: str): @@ -191,7 +192,7 @@ def add_constant(self, var: str, value: typing.Any): self.scopes_constants[-1][var] = value def visible_vars(self): - res_set = set() + res_set = OrderedSet() for s in self.scopes_visible: res_set.update(s) return res_set diff --git a/opshin/optimize/optimize_remove_deadvars.py b/opshin/optimize/optimize_remove_deadvars.py index 75146bfb..e52f3873 100644 --- a/opshin/optimize/optimize_remove_deadvars.py +++ b/opshin/optimize/optimize_remove_deadvars.py @@ -2,6 +2,8 @@ from copy import copy from collections import defaultdict +from ordered_set import OrderedSet + from ..util import CompilingNodeVisitor, CompilingNodeTransformer from ..type_inference import INITIAL_SCOPE from ..typed_ast import TypedAnnAssign @@ -93,7 +95,7 @@ def visit_Module(self, node: Module) -> Module: # collect all variable names collector = NameLoadCollector() collector.visit(node_cp) - loaded_vars = set(collector.loaded.keys()) | {"validator_0"} + loaded_vars = OrderedSet(collector.loaded.keys()) | {"validator_0"} # break if the set of loaded vars did not change -> set of vars to remove does also not change if loaded_vars == self.loaded_vars: break @@ -115,7 +117,7 @@ def visit_If(self, node: If): scope_orelse_cp = self.guaranteed_avail_names[-1].copy() self.exit_scope() # what remains after this in the scope is the intersection of both - for var in set(scope_body_cp).intersection(scope_orelse_cp): + for var in OrderedSet(scope_body_cp).intersection(scope_orelse_cp): self.set_guaranteed(var) return node_cp diff --git a/opshin/rewrite/rewrite_import.py b/opshin/rewrite/rewrite_import.py index 5bdad500..a1fea742 100644 --- a/opshin/rewrite/rewrite_import.py +++ b/opshin/rewrite/rewrite_import.py @@ -6,6 +6,7 @@ import typing import sys from ast import * +from ordered_set import OrderedSet from ..util import CompilingNodeTransformer @@ -57,7 +58,7 @@ class RewriteImport(CompilingNodeTransformer): def __init__(self, filename=None, package=None, resolved_imports=None): self.filename = filename self.package = package - self.resolved_imports = resolved_imports or set() + self.resolved_imports = resolved_imports or OrderedSet() def visit_ImportFrom( self, node: ImportFrom diff --git a/opshin/rewrite/rewrite_scoping.py b/opshin/rewrite/rewrite_scoping.py index 60c695cf..a665e106 100644 --- a/opshin/rewrite/rewrite_scoping.py +++ b/opshin/rewrite/rewrite_scoping.py @@ -2,6 +2,8 @@ from copy import copy from collections import defaultdict +from ordered_set import OrderedSet + from ..type_inference import INITIAL_SCOPE, PolymorphicFunctionInstanceType from ..util import CompilingNodeTransformer, CompilingNodeVisitor @@ -14,7 +16,7 @@ class ShallowNameDefCollector(CompilingNodeVisitor): step = "Collecting occuring variable names" def __init__(self): - self.vars = set() + self.vars = OrderedSet() def visit_Name(self, node: Name) -> None: if isinstance(node.ctx, Store) or isinstance( @@ -36,7 +38,7 @@ class RewriteScoping(CompilingNodeTransformer): def __init__(self): self.latest_scope_id = 0 - self.scopes = [(set(INITIAL_SCOPE.keys()), -1)] + self.scopes = [(OrderedSet(INITIAL_SCOPE.keys()), -1)] def variable_scope_id(self, name: str) -> int: """find the id of the scope in which this variable is defined (closest to its usage)""" @@ -49,7 +51,7 @@ def variable_scope_id(self, name: str) -> int: ) def enter_scope(self): - self.scopes.append((set(), self.latest_scope_id)) + self.scopes.append((OrderedSet(), self.latest_scope_id)) self.latest_scope_id += 1 def exit_scope(self): diff --git a/opshin/tests/test_misc.py b/opshin/tests/test_misc.py index 7f8bde81..f361d001 100644 --- a/opshin/tests/test_misc.py +++ b/opshin/tests/test_misc.py @@ -60,7 +60,16 @@ def test_assert_sum_contract_succeed(self): input_file = "examples/smart_contracts/assert_sum.py" with open(input_file) as fp: source_code = fp.read() - ret = eval_uplc(source_code, 20, 22, Unit()) + ret = eval_uplc( + source_code, + 20, + 22, + uplc.data_from_cbor( + bytes.fromhex( + "d8799fd8799f9fd8799fd8799fd8799f582055d353acacaab6460b37ed0f0e3a1a0aabf056df4a7fa1e265d21149ccacc527ff01ffd8799fd8799fd87a9f581cdbe769758f26efb21f008dc097bb194cffc622acc37fcefc5372eee3ffd87a80ffa140a1401a00989680d87a9f5820dfab81872ce2bbe6ee5af9bbfee4047f91c1f57db5e30da727d5fef1e7f02f4dffd87a80ffffff809fd8799fd8799fd8799f581cdc315c289fee4484eda07038393f21dc4e572aff292d7926018725c2ffd87a80ffa140a14000d87980d87a80ffffa140a14000a140a1400080a0d8799fd8799fd87980d87a80ffd8799fd87b80d87a80ffff80a1d87a9fd8799fd8799f582055d353acacaab6460b37ed0f0e3a1a0aabf056df4a7fa1e265d21149ccacc527ff01ffffd87980a15820dfab81872ce2bbe6ee5af9bbfee4047f91c1f57db5e30da727d5fef1e7f02f4dd8799f581cdc315c289fee4484eda07038393f21dc4e572aff292d7926018725c2ffd8799f5820746957f0eb57f2b11119684e611a98f373afea93473fefbb7632d579af2f6259ffffd87a9fd8799fd8799f582055d353acacaab6460b37ed0f0e3a1a0aabf056df4a7fa1e265d21149ccacc527ff01ffffff" + ) + ), + ) self.assertEqual(ret, uplc.PlutusConstr(0, [])) @unittest.expectedFailure diff --git a/opshin/type_inference.py b/opshin/type_inference.py index feea6778..6ec075ed 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -13,6 +13,7 @@ """ import typing from collections import defaultdict +from ordered_set import OrderedSet from copy import copy from pycardano import PlutusData @@ -142,7 +143,7 @@ def constant_type(c): def union_types(*ts: Type): - ts = list(set(ts)) + ts = OrderedSet(ts) if len(ts) == 1: return ts[0] assert ts, "Union must combine multiple classes" @@ -151,7 +152,7 @@ def union_types(*ts: Type): isinstance(e, UnionType) and all(isinstance(e2, RecordType) for e2 in e.typs) for e in ts ), "Union must combine multiple PlutusData classes" - union_set = set() + union_set = OrderedSet() for t in ts: union_set.update(t.typs) assert distinct( @@ -161,12 +162,12 @@ def union_types(*ts: Type): def intersection_types(*ts: Type): - ts = list(set(ts)) + ts = OrderedSet(ts) if len(ts) == 1: return ts[0] ts = [t if isinstance(t, UnionType) else UnionType(frozenlist([t])) for t in ts] assert ts, "Must have at least one type to intersect" - intersection_set = set(ts[0].typs) + intersection_set = OrderedSet(ts[0].typs) for t in ts[1:]: intersection_set.intersection_update(t.typs) return UnionType(frozenlist(intersection_set)) @@ -261,7 +262,7 @@ def visit_UnaryOp(self, node: UnaryOp) -> PairType: def merge_scope(s1: typing.Dict[str, Type], s2: typing.Dict[str, Type]): - keys = set(s1.keys()).union(s2.keys()) + keys = OrderedSet(s1.keys()).union(s2.keys()) merged = {} for k in keys: if k not in s1.keys(): diff --git a/opshin/types.py b/opshin/types.py index 4bd0fe91..8d9f0119 100644 --- a/opshin/types.py +++ b/opshin/types.py @@ -2,6 +2,7 @@ from ast import * import itertools +from ordered_set import OrderedSet import uplc.ast @@ -535,7 +536,7 @@ def attribute_type(self, attr) -> "Type": return IntegerInstanceType # need to have a common field with the same name if all(attr in (n for n, t in x.record.fields) for x in self.typs): - attr_types = set( + attr_types = OrderedSet( t for x in self.typs for n, t in x.record.fields if n == attr ) for at in attr_types: diff --git a/poetry.lock b/poetry.lock index 455b9b3b..3bdd43e1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1199,6 +1199,20 @@ files = [ [package.dependencies] setuptools = "*" +[[package]] +name = "ordered-set" +version = "4.1.0" +description = "An OrderedSet is a custom MutableSet that remembers its order, so that every" +optional = false +python-versions = ">=3.7" +files = [ + {file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"}, + {file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"}, +] + +[package.extras] +dev = ["black", "mypy", "pytest"] + [[package]] name = "oscrypto" version = "1.3.0" @@ -2206,4 +2220,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8, <3.12" -content-hash = "41d7679c404c68c497a7b88550933d636e885be62ffcd0a0d1b0ed5e1710d319" +content-hash = "5141133b0028d640131de080176a0cfa69e53a7d1c37b52d4ac979c1201da3dd" diff --git a/pyproject.toml b/pyproject.toml index bbd1d79a..5dbb54f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ pluthon = "^0.4.1" pycardano = "^0.9.0" frozenlist2 = "^1.0.0" astunparse = {version = "^1.6.3", python = "<3.9"} +ordered-set = "^4.1.0" [tool.poetry.group.dev.dependencies]