diff --git a/opshin/compiler.py b/opshin/compiler.py index ac8b0e27..1cf0b035 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -393,6 +393,8 @@ def visit_Name(self, node: TypedName) -> plt.AST: if isinstance(node.typ, ClassType): # if this is not an instance but a class, call the constructor return node.typ.constr() + if hasattr(node, "is_wrapped") and node.is_wrapped: + return transform_ext_params_map(node.typ)(plt.Force(plt.Var(node.id))) return plt.Force(plt.Var(node.id)) def visit_Expr(self, node: TypedExpr) -> CallAST: @@ -433,7 +435,7 @@ def visit_Call(self, node: TypedCall) -> plt.AST: assert isinstance(t, InstanceType) # pass in all arguments evaluated with the statemonad a_int = self.visit(a) - if isinstance(t.typ, AnyType): + if isinstance(t.typ, AnyType) or isinstance(t.typ, UnionType): # if the function expects input of generic type data, wrap data before passing it inside a_int = transform_output_map(a.typ)(a_int) args.append(a_int) @@ -914,6 +916,14 @@ def visit_Dict(self, node: TypedDict) -> plt.AST: return l def visit_IfExp(self, node: TypedIfExp) -> plt.AST: + if isinstance(node.typ.typ, UnionType): + body = self.visit(node.body) + orelse = self.visit(node.orelse) + if not isinstance(node.body.typ, UnionType): + body = transform_output_map(node.body.typ)(body) + if not isinstance(node.orelse.typ, UnionType): + orelse = transform_output_map(node.orelse.typ)(orelse) + return plt.Ite(self.visit(node.test), body, orelse) return plt.Ite( self.visit(node.test), self.visit(node.body), diff --git a/opshin/fun_impls.py b/opshin/fun_impls.py index 7a65dee8..ea19bb15 100644 --- a/opshin/fun_impls.py +++ b/opshin/fun_impls.py @@ -95,6 +95,12 @@ def type_from_args(self, args: typing.List[Type]) -> FunctionType: return FunctionType(args, BoolInstanceType) def impl_from_args(self, args: typing.List[Type]) -> plt.AST: + if not (isinstance(args[0], UnionType) or isinstance(args[0].typ, UnionType)): + if args[0].typ == args[1]: + return OLambda(["x"], plt.Bool(True)) + else: + return OLambda(["x"], plt.Bool(False)) + if isinstance(args[1], IntegerType): return OLambda( ["x"], diff --git a/opshin/rewrite/rewrite_forbidden_overwrites.py b/opshin/rewrite/rewrite_forbidden_overwrites.py index 488e1a22..72347eeb 100644 --- a/opshin/rewrite/rewrite_forbidden_overwrites.py +++ b/opshin/rewrite/rewrite_forbidden_overwrites.py @@ -13,6 +13,7 @@ "List", "Dict", "Union", + "Self", # decorator and class name "dataclass", "PlutusData", diff --git a/opshin/rewrite/rewrite_import_typing.py b/opshin/rewrite/rewrite_import_typing.py index 6f58414f..a906e7aa 100644 --- a/opshin/rewrite/rewrite_import_typing.py +++ b/opshin/rewrite/rewrite_import_typing.py @@ -49,6 +49,16 @@ def visit_ClassDef(self, node: ClassDef) -> ClassDef: and arg.annotation.id == "Self" ): node.body[i].args.args[j].annotation.idSelf = node.name + if ( + isinstance(arg.annotation, Subscript) + and arg.annotation.value.id == "Union" + ): + for k, s in enumerate(arg.annotation.slice.elts): + if isinstance(s, Name) and s.id == "Self": + node.body[i].args.args[j].annotation.slice.elts[ + k + ].idSelf = node.name + if ( isinstance(attribute.returns, Name) and attribute.returns.id == "Self" diff --git a/opshin/rewrite/rewrite_scoping.py b/opshin/rewrite/rewrite_scoping.py index 560af9a3..ca401ad3 100644 --- a/opshin/rewrite/rewrite_scoping.py +++ b/opshin/rewrite/rewrite_scoping.py @@ -40,6 +40,7 @@ class RewriteScoping(CompilingNodeTransformer): step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope" latest_scope_id: int scopes: typing.List[typing.Tuple[OrderedSet, int]] + current_Self: typing.Tuple[str, str] def variable_scope_id(self, name: str) -> int: """find the id of the scope in which this variable is defined (closest to its usage)""" @@ -86,6 +87,9 @@ def visit_Module(self, node: Module) -> Module: def visit_Name(self, node: Name) -> Name: nc = copy(node) # setting is handled in either enclosing module or function + if node.id == "Self": + assert node.idSelf == self.current_Self[1] + nc.idSelf_new = self.current_Self[0] nc.id = self.map_name(node.id) return nc @@ -93,6 +97,7 @@ def visit_ClassDef(self, node: ClassDef) -> ClassDef: cp_node = RecordScoper.scope(node, self) for i, attribute in enumerate(cp_node.body): if isinstance(attribute, FunctionDef): + self.current_Self = (cp_node.name, cp_node.orig_name) cp_node.body[i] = self.visit_FunctionDef(attribute, method=True) return cp_node @@ -108,17 +113,9 @@ def visit_FunctionDef(self, node: FunctionDef, method: bool = False) -> Function a_cp = copy(a) self.set_variable_scope(a.arg) a_cp.arg = self.map_name(a.arg) - a_cp.annotation = ( - self.visit(a.annotation) - if not hasattr(a.annotation, "idSelf") - else a.annotation - ) + a_cp.annotation = self.visit(a.annotation) node_cp.args.args.append(a_cp) - node_cp.returns = ( - self.visit(node.returns) - if not hasattr(node.returns, "idSelf") - else node.returns - ) + node_cp.returns = self.visit(node.returns) # vars defined in this scope shallow_node_def_collector = ShallowNameDefCollector() for s in node.body: diff --git a/opshin/std/fractions.py b/opshin/std/fractions.py index caad5f83..9c96e6bd 100644 --- a/opshin/std/fractions.py +++ b/opshin/std/fractions.py @@ -17,33 +17,165 @@ class Fraction(PlutusData): numerator: int denominator: int - -def add_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns a + b""" - return Fraction( - (a.numerator * b.denominator) + (b.numerator * a.denominator), - a.denominator * b.denominator, - ) - - -def neg_fraction(a: Fraction) -> Fraction: - """returns -a""" - return Fraction(-a.numerator, a.denominator) - - -def sub_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns a - b""" - return add_fraction(a, neg_fraction(b)) - - -def mul_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns a * b""" - return Fraction(a.numerator * b.numerator, a.denominator * b.denominator) - - -def div_fraction(a: Fraction, b: Fraction) -> Fraction: - """returns a / b""" - return Fraction(a.numerator * b.denominator, a.denominator * b.numerator) + def norm(self) -> "Fraction": + """Restores the invariant that num/denom are in the smallest possible denomination and denominator > 0""" + return _norm_gcd_fraction(_norm_signs_fraction(self)) + + def ceil(self) -> int: + return ( + self.numerator + self.denominator - sign(self.denominator) + ) // self.denominator + + def __add__(self, other: Union["Fraction", int]) -> "Fraction": + """returns self + other""" + if isinstance(other, Fraction): + return Fraction( + (self.numerator * other.denominator) + + (other.numerator * self.denominator), + self.denominator * other.denominator, + ) + else: + return Fraction( + (self.numerator) + (other * self.denominator), + self.denominator, + ) + + def __neg__( + self, + ) -> "Fraction": + """returns -self""" + return Fraction(-self.numerator, self.denominator) + + def __sub__(self, other: Union["Fraction", int]) -> "Fraction": + """returns self - other""" + if isinstance(other, Fraction): + return Fraction( + (self.numerator * other.denominator) + - (other.numerator * self.denominator), + self.denominator * other.denominator, + ) + else: + return Fraction( + self.numerator - (other * self.denominator), self.denominator + ) + + def __mul__(self, other: Union["Fraction", int]) -> "Fraction": + """returns self * other""" + if isinstance(other, Fraction): + return Fraction( + self.numerator * other.numerator, self.denominator * other.denominator + ) + else: + return Fraction(self.numerator * other, self.denominator) + + def __truediv__(self, other: Union["Fraction", int]) -> "Fraction": + """returns self / other""" + if isinstance(other, Fraction): + return Fraction( + self.numerator * other.denominator, self.denominator * other.numerator + ) + else: + return Fraction(self.numerator, self.denominator * other) + + def __ge__(self, other: Union["Fraction", int]) -> bool: + """returns self >= other""" + if isinstance(other, Fraction): + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator + >= self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator + <= self.denominator * other.numerator + ) + return res + else: + if self.denominator >= 0: + res = self.numerator >= self.denominator * other + else: + res = self.numerator <= self.denominator * other + return res + + def __le__(self, other: Union["Fraction", int]) -> bool: + """returns self <= other""" + if isinstance(other, Fraction): + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator + <= self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator + >= self.denominator * other.numerator + ) + return res + else: + if self.denominator >= 0: + res = self.numerator <= self.denominator * other + else: + res = self.numerator >= self.denominator * other + return res + + def __eq__(self, other: Union["Fraction", int]) -> bool: + """returns self == other""" + if isinstance(other, Fraction): + return ( + self.numerator * other.denominator == self.denominator * other.numerator + ) + else: + return self.numerator == self.denominator * other + + def __lt__(self, other: Union["Fraction", int]) -> bool: + """returns self < other""" + if isinstance(other, Fraction): + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator + < self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator + > self.denominator * other.numerator + ) + return res + else: + if self.denominator >= 0: + res = self.numerator < self.denominator * other + else: + res = self.numerator > self.denominator * other + return res + + def __gt__(self, other: Union["Fraction", int]) -> bool: + """returns self > other""" + if isinstance(other, Fraction): + if self.denominator * other.denominator >= 0: + res = ( + self.numerator * other.denominator + > self.denominator * other.numerator + ) + else: + res = ( + self.numerator * other.denominator + < self.denominator * other.numerator + ) + return res + else: + if self.denominator >= 0: + res = self.numerator > self.denominator * other + else: + res = self.numerator < self.denominator * other + return res + + def __floordiv__(self, other: Union["Fraction", int]) -> int: + if isinstance(other, Fraction): + x = self / other + return x.numerator // x.denominator + else: + return self.numerator // (other * self.denominator) def _norm_signs_fraction(a: Fraction) -> Fraction: @@ -62,50 +194,5 @@ def norm_fraction(a: Fraction) -> Fraction: return _norm_gcd_fraction(_norm_signs_fraction(a)) -def ge_fraction(a: Fraction, b: Fraction) -> bool: - """returns a >= b""" - if a.denominator * b.denominator >= 0: - res = a.numerator * b.denominator >= a.denominator * b.numerator - else: - res = a.numerator * b.denominator <= a.denominator * b.numerator - return res - - -def le_fraction(a: Fraction, b: Fraction) -> bool: - """returns a <= b""" - if a.denominator * b.denominator >= 0: - res = a.numerator * b.denominator <= a.denominator * b.numerator - else: - res = a.numerator * b.denominator >= a.denominator * b.numerator - return res - - -def eq_fraction(a: Fraction, b: Fraction) -> bool: - """returns a == b""" - return a.numerator * b.denominator == a.denominator * b.numerator - - -def lt_fraction(a: Fraction, b: Fraction) -> bool: - """returns a < b""" - if a.denominator * b.denominator >= 0: - res = a.numerator * b.denominator < a.denominator * b.numerator - else: - res = a.numerator * b.denominator > a.denominator * b.numerator - return res - - -def gt_fraction(a: Fraction, b: Fraction) -> bool: - """returns a > b""" - if a.denominator * b.denominator >= 0: - res = a.numerator * b.denominator > a.denominator * b.numerator - else: - res = a.numerator * b.denominator < a.denominator * b.numerator - return res - - -def floor_fraction(a: Fraction) -> int: - return a.numerator // a.denominator - - def ceil_fraction(a: Fraction) -> int: return (a.numerator + a.denominator - sign(a.denominator)) // a.denominator diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 4c7c3d3d..e31c6df6 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -279,6 +279,7 @@ class AggressiveTypeInferencer(CompilingNodeTransformer): def __init__(self, allow_isinstance_anything=False): self.allow_isinstance_anything = allow_isinstance_anything self.FUNCTION_ARGUMENT_REGISTRY = {} + self.wrapped = [] # A stack of dictionaries for storing scoped knowledge of variable types self.scopes = [INITIAL_SCOPE] @@ -366,10 +367,22 @@ def type_from_annotation(self, ann: expr): if isinstance(ann, Constant): if ann.value is None: return UnitType() + else: + for scope in reversed(self.scopes): + for key, value in scope.items(): + if ( + isinstance(value, RecordType) + and value.record.orig_name == ann.value + ): + return value + if isinstance(ann, Name): if ann.id in ATOMIC_TYPES: return ATOMIC_TYPES[ann.id] - v_t = self.variable_type(ann.id) + if ann.id == "Self": + v_t = self.variable_type(ann.idSelf_new) + else: + v_t = self.variable_type(ann.id) if isinstance(v_t, ClassType): return v_t raise TypeInferenceError( @@ -465,11 +478,23 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: ), f"The following Dunder methods are supported {list(DUNDER_MAP.values())}. Received {func.name} which is not supported" func.name = f"{n.name}_{attribute.name}" for arg in func.args.args: - assert ( - arg.annotation is None or arg.annotation.id != n.name - ), "Invalid Python, class name is undefined at this stage." - assert ( - func.returns is None or func.returns.id != n.name + if not arg.annotation is None: + if isinstance(arg.annotation, ast.Name): + assert ( + arg.annotation is None or arg.annotation.id != n.name + ), "Invalid Python, class name is undefined at this stage." + elif ( + isinstance(arg.annotation, ast.Subscript) + and arg.annotation.value.id == "Union" + ): + for s in arg.annotation.slice.elts: + assert ( + isinstance(s, Name) and s.id != n.name + ) or isinstance( + s, Constant + ), "Invalid Python, class name is undefined at this stage." + assert isinstance(func.returns, Constant) or ( + isinstance(func.returns, Name) and func.returns.id != n.name ), "Invalid Python, class name is undefined at this stage" ann = ast.Name(id=n.name, ctx=ast.Load()) custom_fix_missing_locations(ann, attribute.args.args[0]) @@ -612,15 +637,20 @@ def visit_If(self, node: If) -> TypedIf: ).visit(typed_if.test) # for the time of the branch, these types are cast initial_scope = copy(self.scopes[-1]) - self.implement_typechecks(typchecks) + wrapped = self.implement_typechecks(typchecks) + self.wrapped.extend(wrapped.keys()) typed_if.body = self.visit_sequence(node.body) + self.wrapped = [x for x in self.wrapped if x not in wrapped.keys()] + # save resulting types final_scope_body = copy(self.scopes[-1]) # reverse typechecks and remove typing of one branch self.scopes[-1] = initial_scope # for the time of the else branch, the inverse types hold - self.implement_typechecks(inv_typchecks) + wrapped = self.implement_typechecks(inv_typchecks) + self.wrapped.extend(wrapped.keys()) typed_if.orelse = self.visit_sequence(node.orelse) + self.wrapped = [x for x in self.wrapped if x not in wrapped.keys()] final_scope_else = self.scopes[-1] # unify the resulting branch scopes self.scopes[-1] = merge_scope(final_scope_body, final_scope_else) @@ -689,6 +719,8 @@ def visit_Name(self, node: Name) -> TypedName: else: # Make sure that the rhs of an assign is evaluated first tn.typ = self.variable_type(node.id) + if node.id in self.wrapped: + tn.is_wrapped = True return tn def visit_keyword(self, node: keyword) -> Typedkeyword: @@ -851,7 +883,6 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript: "Dict", "List", ]: - ts.value = ts.typ = self.type_from_annotation(ts) return ts @@ -1119,10 +1150,15 @@ def visit_IfExp(self, node: IfExp) -> TypedIfExp: self.allow_isinstance_anything ).visit(node_cp.test) prevtyps = self.implement_typechecks(typchecks) + self.wrapped.extend(prevtyps.keys()) node_cp.body = self.visit(node.body) + self.wrapped = [x for x in self.wrapped if x not in prevtyps.keys()] + self.implement_typechecks(prevtyps) prevtyps = self.implement_typechecks(inv_typchecks) + self.wrapped.extend(prevtyps.keys()) node_cp.orelse = self.visit(node.orelse) + self.wrapped = [x for x in self.wrapped if x not in prevtyps.keys()] self.implement_typechecks(prevtyps) if node_cp.body.typ >= node_cp.orelse.typ: node_cp.typ = node_cp.body.typ diff --git a/tests/test_Unions.py b/tests/test_Unions.py index 30675b91..2aad1f2d 100644 --- a/tests/test_Unions.py +++ b/tests/test_Unions.py @@ -305,3 +305,182 @@ def validator(x: Union[int, bytes, bool]) -> int: with self.assertRaises(CompilerError) as ce: res = eval_uplc_value(source_code, True) self.assertIsInstance(ce.exception.orig_err, AssertionError) + + @hypothesis.given(st.sampled_from([14, b""])) + def test_Union_builtin_cast(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +def validator(x: Union[int,bytes]) -> int: + k: int = 0 + if isinstance(x, int): + k = x+5 + elif isinstance(x, bytes): + k = len(x) + return k +""" + res = eval_uplc_value(source_code, x) + real = x + 5 if isinstance(x, int) else len(x) + self.assertEqual(res, real) + + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_builtin_cast_internal(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +def foo(x: Union[int,bytes]) -> int: + k: int = 0 + if isinstance(x, int): + k = x+5 + elif isinstance(x, bytes): + k = len(x) + return k + +def validator(x: int) -> int: + if x > 5: + k = foo(x+1) + else: + k = foo(b"0"*x) + return k +""" + res = eval_uplc_value(source_code, x) + real = x + 6 if x > 5 else len(b"0" * x) + self.assertEqual(res, real) + + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_builtin_cast_direct(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +def validator(x: int) -> int: + y: Union[int,bytes] = 5 if x > 5 else b"0"*x + k: int = 0 + if isinstance(y, int): + k = y+1 + elif isinstance(y, bytes): + k = len(y) + return k +""" + res = eval_uplc_value(source_code, x) + real = 5 + 1 if x > 5 else len(b"0" * x) + self.assertEqual(res, real) + + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_cast_ifexpr(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + x: int + +@dataclass() +class B(PlutusData): + CONSTR_ID = 1 + y: bytes + +def foo(x: Union[A, B]) -> int: + k: int = x.x + 1 if isinstance(x, A) else len(x.y) + return k + +def validator(x: int) -> int: + if x > 5: + k = foo(A(x)) + else: + k = foo(B(b"0"*x)) + return k +""" + res = eval_uplc_value(source_code, x) + real = x + 1 if x > 5 else len(b"0" * x) + self.assertEqual(res, real) + + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_builtin_cast_ifexpr(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +def foo(x: Union[int, bytes]) -> int: + k: int = x + 1 if isinstance(x, int) else len(x) + return k + +def validator(x: int) -> int: + if x > 5: + k = foo(x+1) + else: + k = foo(b"0"*x) + return k +""" + res = eval_uplc_value(source_code, x) + real = x + 2 if x > 5 else len(b"0" * x) + self.assertEqual(res, real) + + @unittest.skip("Throw compilation error, hence not critical") + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_cast_List(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +@dataclass() +class A(PlutusData): + CONSTR_ID = 0 + x: int + +@dataclass() +class B(PlutusData): + CONSTR_ID = 1 + y: bytes + +def foo(xs: List[Union[A, B]]) -> List[int]: + k: List[int] = [x.x + 1 for x in xs if isinstance(x, A)] + if not k: + k = [len(x.y) for x in xs if isinstance(x, B)] + return k + +def validator(x: int) -> int: + if x > 5: + k = foo([A(x)]) + else: + k = foo([B(b"0"*x)]) + return k[0] +""" + res = eval_uplc_value(source_code, x) + real = x + 1 if x > 5 else len(b"0" * x) + self.assertEqual(res, real) + + @unittest.skip("Throw compilation error, hence not critical") + @hypothesis.given(st.sampled_from(range(14))) + def test_Union_builtin_cast_List(self, x): + source_code = """ +from dataclasses import dataclass +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData + +def foo(xs: List[Union[int, bytes]]) -> List[int]: + k: List[int] = [x + 1 for x in xs if isinstance(x, int)] + if not k: + k = [len(x) for x in xs if isinstance(x, bytes)] + return k + +def validator(x: int) -> int: + if x > 5: + k = foo(x+1) + else: + k = foo(b"0"*x) + return k[0] +""" + res = eval_uplc_value(source_code, x) + real = x + 2 if x > 5 else len(b"0" * x) + self.assertEqual(res, real) diff --git a/tests/test_std/test_fractions.py b/tests/test_std/test_fractions.py index 60330b81..2c09b927 100644 --- a/tests/test_std/test_fractions.py +++ b/tests/test_std/test_fractions.py @@ -1,32 +1,50 @@ import hypothesis import hypothesis.strategies as hst - +from typing import Union from opshin.std import fractions as oc_fractions +from ..utils import eval_uplc, eval_uplc_value import fractions as native_fractions import math as native_math +from uplc.ast import PlutusConstr + non_null = hst.one_of(hst.integers(min_value=1), hst.integers(max_value=-1)) denormalized_fractions = hst.builds(oc_fractions.Fraction, hst.integers(), non_null) denormalized_fractions_non_null = hst.builds(oc_fractions.Fraction, non_null, non_null) +denormalized_fractions_and_int = hst.one_of([denormalized_fractions, hst.integers()]) +denormalized_fractions_and_int_non_null = hst.one_of( + [denormalized_fractions_non_null, non_null] +) -def native_fraction_from_oc_fraction(f: oc_fractions.Fraction): - return native_fractions.Fraction(f.numerator, f.denominator) +def native_fraction_from_oc_fraction(f: Union[oc_fractions.Fraction, int]): + if isinstance(f, oc_fractions.Fraction): + return native_fractions.Fraction(f.numerator, f.denominator) + elif isinstance(f, PlutusConstr): + return native_fractions.Fraction(*[x.value for x in f.fields]) + else: + return f -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_add(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_added = oc_fractions.add_fraction(a, b) +def plutus_to_native(f): + assert isinstance(f, PlutusConstr) + assert f.constructor == 1 + return native_fractions.Fraction(*[field.value for field in f.fields]) + + +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) +def test_add_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_added = a + b oc_normalized = native_fraction_from_oc_fraction(oc_added) assert oc_normalized == ( native_fraction_from_oc_fraction(a) + native_fraction_from_oc_fraction(b) ), "Invalid add" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_sub(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_subbed = oc_fractions.sub_fraction(a, b) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) +def test_sub_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_subbed = a - b oc_normalized = native_fraction_from_oc_fraction(oc_subbed) assert oc_normalized == ( native_fraction_from_oc_fraction(a) - native_fraction_from_oc_fraction(b) @@ -34,24 +52,24 @@ def test_sub(a: oc_fractions.Fraction, b: oc_fractions.Fraction): @hypothesis.given(denormalized_fractions) -def test_neg(a: oc_fractions.Fraction): - oc_negged = oc_fractions.neg_fraction(a) +def test_neg_dunder(a: oc_fractions.Fraction): + oc_negged = -a oc_normalized = native_fraction_from_oc_fraction(oc_negged) assert oc_normalized == -native_fraction_from_oc_fraction(a), "Invalid neg" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_mul(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_mulled = oc_fractions.mul_fraction(a, b) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) +def test_mul_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_mulled = a * b oc_normalized = native_fraction_from_oc_fraction(oc_mulled) assert oc_normalized == ( native_fraction_from_oc_fraction(a) * native_fraction_from_oc_fraction(b) ), "Invalid mul" -@hypothesis.given(denormalized_fractions, denormalized_fractions_non_null) -def test_div(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_divved = oc_fractions.div_fraction(a, b) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int_non_null) +def test_div_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_divved = a / b oc_normalized = native_fraction_from_oc_fraction(oc_divved) assert oc_normalized == ( native_fraction_from_oc_fraction(a) / native_fraction_from_oc_fraction(b) @@ -76,47 +94,56 @@ def test_norm(a: oc_fractions.Fraction): assert oc_normed.denominator == oc_normalized.denominator, "Invalid norm" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_ge(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_ge = oc_fractions.ge_fraction(a, b) +@hypothesis.given(denormalized_fractions) +@hypothesis.example(oc_fractions.Fraction(0, -1)) +def test_norm_method(a: oc_fractions.Fraction): + oc_normed = a.norm() + oc_normalized = native_fraction_from_oc_fraction(a) + assert oc_normed.numerator == oc_normalized.numerator, "Invalid norm" + assert oc_normed.denominator == oc_normalized.denominator, "Invalid norm" + + +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) +def test_ge_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_ge = a >= b ge = native_fraction_from_oc_fraction(a) >= native_fraction_from_oc_fraction(b) assert oc_ge == ge, "Invalid ge" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_le(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_le = oc_fractions.le_fraction(a, b) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) +def test_le_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_le = a <= b le = native_fraction_from_oc_fraction(a) <= native_fraction_from_oc_fraction(b) assert oc_le == le, "Invalid le" -@hypothesis.given(denormalized_fractions, denormalized_fractions) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) def test_lt(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_lt = oc_fractions.lt_fraction(a, b) + oc_lt = a < b lt = native_fraction_from_oc_fraction(a) < native_fraction_from_oc_fraction(b) assert oc_lt == lt, "Invalid lt" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_gt(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_gt = oc_fractions.gt_fraction(a, b) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) +def test_gt_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_gt = a > b gt = native_fraction_from_oc_fraction(a) > native_fraction_from_oc_fraction(b) assert oc_gt == gt, "Invalid gt" -@hypothesis.given(denormalized_fractions, denormalized_fractions) -def test_eq(a: oc_fractions.Fraction, b: oc_fractions.Fraction): - oc_eq = oc_fractions.eq_fraction(a, b) +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int) +def test_eq_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_eq = a == b eq = native_fraction_from_oc_fraction(a) == native_fraction_from_oc_fraction(b) assert oc_eq == eq, "Invalid eq" -@hypothesis.given(denormalized_fractions) -def test_floor(a: oc_fractions.Fraction): - oc_floor = oc_fractions.floor_fraction(a) - assert ( - native_math.floor(native_fraction_from_oc_fraction(a)) == oc_floor - ), "Invalid floor" +@hypothesis.given(denormalized_fractions, denormalized_fractions_and_int_non_null) +def test_floor_dunder(a: oc_fractions.Fraction, b: oc_fractions.Fraction): + oc_floor = a // b + floor = native_fraction_from_oc_fraction(a) // native_fraction_from_oc_fraction(b) + + assert oc_floor == floor, "Invalid floor" @hypothesis.given(denormalized_fractions) @@ -125,3 +152,27 @@ def test_ceil(a: oc_fractions.Fraction): assert ( native_math.ceil(native_fraction_from_oc_fraction(a)) == oc_ceil ), "Invalid ceil" + + +@hypothesis.given(denormalized_fractions) +def test_ceil_method(a: oc_fractions.Fraction): + oc_ceil = a.ceil() + assert ( + native_math.ceil(native_fraction_from_oc_fraction(a)) == oc_ceil + ), "Invalid ceil" + + +@hypothesis.given(denormalized_fractions, denormalized_fractions) +def test_uplc(a, b): + source_code = """ +from opshin.std.fractions import * +from typing import Dict, List, Union + +def validator(a: Fraction, b: Union[Fraction, int]) -> Fraction: + return a+b +""" + ret = eval_uplc(source_code, a, b) + print(ret) + assert ( + native_fraction_from_oc_fraction(a) + native_fraction_from_oc_fraction(b) + ) == native_fraction_from_oc_fraction(ret), "invalid add"