Skip to content

Commit

Permalink
Merge pull request #403 from SCMusson/feat_frac_dunder
Browse files Browse the repository at this point in the history
Fractions functionality issue #395
  • Loading branch information
nielstron authored Oct 8, 2024
2 parents fd413f4 + 2eb83db commit 26f7532
Show file tree
Hide file tree
Showing 9 changed files with 506 additions and 129 deletions.
12 changes: 11 additions & 1 deletion opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def visit_Name(self, node: TypedName) -> plt.AST:
if isinstance(node.typ, ClassType):
# if this is not an instance but a class, call the constructor
return node.typ.constr()
if hasattr(node, "is_wrapped") and node.is_wrapped:
return transform_ext_params_map(node.typ)(plt.Force(plt.Var(node.id)))
return plt.Force(plt.Var(node.id))

def visit_Expr(self, node: TypedExpr) -> CallAST:
Expand Down Expand Up @@ -433,7 +435,7 @@ def visit_Call(self, node: TypedCall) -> plt.AST:
assert isinstance(t, InstanceType)
# pass in all arguments evaluated with the statemonad
a_int = self.visit(a)
if isinstance(t.typ, AnyType):
if isinstance(t.typ, AnyType) or isinstance(t.typ, UnionType):
# if the function expects input of generic type data, wrap data before passing it inside
a_int = transform_output_map(a.typ)(a_int)
args.append(a_int)
Expand Down Expand Up @@ -914,6 +916,14 @@ def visit_Dict(self, node: TypedDict) -> plt.AST:
return l

def visit_IfExp(self, node: TypedIfExp) -> plt.AST:
if isinstance(node.typ.typ, UnionType):
body = self.visit(node.body)
orelse = self.visit(node.orelse)
if not isinstance(node.body.typ, UnionType):
body = transform_output_map(node.body.typ)(body)
if not isinstance(node.orelse.typ, UnionType):
orelse = transform_output_map(node.orelse.typ)(orelse)
return plt.Ite(self.visit(node.test), body, orelse)
return plt.Ite(
self.visit(node.test),
self.visit(node.body),
Expand Down
6 changes: 6 additions & 0 deletions opshin/fun_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def type_from_args(self, args: typing.List[Type]) -> FunctionType:
return FunctionType(args, BoolInstanceType)

def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
if not (isinstance(args[0], UnionType) or isinstance(args[0].typ, UnionType)):
if args[0].typ == args[1]:
return OLambda(["x"], plt.Bool(True))
else:
return OLambda(["x"], plt.Bool(False))

if isinstance(args[1], IntegerType):
return OLambda(
["x"],
Expand Down
1 change: 1 addition & 0 deletions opshin/rewrite/rewrite_forbidden_overwrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"List",
"Dict",
"Union",
"Self",
# decorator and class name
"dataclass",
"PlutusData",
Expand Down
10 changes: 10 additions & 0 deletions opshin/rewrite/rewrite_import_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ def visit_ClassDef(self, node: ClassDef) -> ClassDef:
and arg.annotation.id == "Self"
):
node.body[i].args.args[j].annotation.idSelf = node.name
if (
isinstance(arg.annotation, Subscript)
and arg.annotation.value.id == "Union"
):
for k, s in enumerate(arg.annotation.slice.elts):
if isinstance(s, Name) and s.id == "Self":
node.body[i].args.args[j].annotation.slice.elts[
k
].idSelf = node.name

if (
isinstance(attribute.returns, Name)
and attribute.returns.id == "Self"
Expand Down
17 changes: 7 additions & 10 deletions opshin/rewrite/rewrite_scoping.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class RewriteScoping(CompilingNodeTransformer):
step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope"
latest_scope_id: int
scopes: typing.List[typing.Tuple[OrderedSet, int]]
current_Self: typing.Tuple[str, str]

def variable_scope_id(self, name: str) -> int:
"""find the id of the scope in which this variable is defined (closest to its usage)"""
Expand Down Expand Up @@ -86,13 +87,17 @@ def visit_Module(self, node: Module) -> Module:
def visit_Name(self, node: Name) -> Name:
nc = copy(node)
# setting is handled in either enclosing module or function
if node.id == "Self":
assert node.idSelf == self.current_Self[1]
nc.idSelf_new = self.current_Self[0]
nc.id = self.map_name(node.id)
return nc

def visit_ClassDef(self, node: ClassDef) -> ClassDef:
cp_node = RecordScoper.scope(node, self)
for i, attribute in enumerate(cp_node.body):
if isinstance(attribute, FunctionDef):
self.current_Self = (cp_node.name, cp_node.orig_name)
cp_node.body[i] = self.visit_FunctionDef(attribute, method=True)
return cp_node

Expand All @@ -108,17 +113,9 @@ def visit_FunctionDef(self, node: FunctionDef, method: bool = False) -> Function
a_cp = copy(a)
self.set_variable_scope(a.arg)
a_cp.arg = self.map_name(a.arg)
a_cp.annotation = (
self.visit(a.annotation)
if not hasattr(a.annotation, "idSelf")
else a.annotation
)
a_cp.annotation = self.visit(a.annotation)
node_cp.args.args.append(a_cp)
node_cp.returns = (
self.visit(node.returns)
if not hasattr(node.returns, "idSelf")
else node.returns
)
node_cp.returns = self.visit(node.returns)
# vars defined in this scope
shallow_node_def_collector = ShallowNameDefCollector()
for s in node.body:
Expand Down
231 changes: 159 additions & 72 deletions opshin/std/fractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,165 @@ class Fraction(PlutusData):
numerator: int
denominator: int


def add_fraction(a: Fraction, b: Fraction) -> Fraction:
"""returns a + b"""
return Fraction(
(a.numerator * b.denominator) + (b.numerator * a.denominator),
a.denominator * b.denominator,
)


def neg_fraction(a: Fraction) -> Fraction:
"""returns -a"""
return Fraction(-a.numerator, a.denominator)


def sub_fraction(a: Fraction, b: Fraction) -> Fraction:
"""returns a - b"""
return add_fraction(a, neg_fraction(b))


def mul_fraction(a: Fraction, b: Fraction) -> Fraction:
"""returns a * b"""
return Fraction(a.numerator * b.numerator, a.denominator * b.denominator)


def div_fraction(a: Fraction, b: Fraction) -> Fraction:
"""returns a / b"""
return Fraction(a.numerator * b.denominator, a.denominator * b.numerator)
def norm(self) -> "Fraction":
"""Restores the invariant that num/denom are in the smallest possible denomination and denominator > 0"""
return _norm_gcd_fraction(_norm_signs_fraction(self))

def ceil(self) -> int:
return (
self.numerator + self.denominator - sign(self.denominator)
) // self.denominator

def __add__(self, other: Union["Fraction", int]) -> "Fraction":
"""returns self + other"""
if isinstance(other, Fraction):
return Fraction(
(self.numerator * other.denominator)
+ (other.numerator * self.denominator),
self.denominator * other.denominator,
)
else:
return Fraction(
(self.numerator) + (other * self.denominator),
self.denominator,
)

def __neg__(
self,
) -> "Fraction":
"""returns -self"""
return Fraction(-self.numerator, self.denominator)

def __sub__(self, other: Union["Fraction", int]) -> "Fraction":
"""returns self - other"""
if isinstance(other, Fraction):
return Fraction(
(self.numerator * other.denominator)
- (other.numerator * self.denominator),
self.denominator * other.denominator,
)
else:
return Fraction(
self.numerator - (other * self.denominator), self.denominator
)

def __mul__(self, other: Union["Fraction", int]) -> "Fraction":
"""returns self * other"""
if isinstance(other, Fraction):
return Fraction(
self.numerator * other.numerator, self.denominator * other.denominator
)
else:
return Fraction(self.numerator * other, self.denominator)

def __truediv__(self, other: Union["Fraction", int]) -> "Fraction":
"""returns self / other"""
if isinstance(other, Fraction):
return Fraction(
self.numerator * other.denominator, self.denominator * other.numerator
)
else:
return Fraction(self.numerator, self.denominator * other)

def __ge__(self, other: Union["Fraction", int]) -> bool:
"""returns self >= other"""
if isinstance(other, Fraction):
if self.denominator * other.denominator >= 0:
res = (
self.numerator * other.denominator
>= self.denominator * other.numerator
)
else:
res = (
self.numerator * other.denominator
<= self.denominator * other.numerator
)
return res
else:
if self.denominator >= 0:
res = self.numerator >= self.denominator * other
else:
res = self.numerator <= self.denominator * other
return res

def __le__(self, other: Union["Fraction", int]) -> bool:
"""returns self <= other"""
if isinstance(other, Fraction):
if self.denominator * other.denominator >= 0:
res = (
self.numerator * other.denominator
<= self.denominator * other.numerator
)
else:
res = (
self.numerator * other.denominator
>= self.denominator * other.numerator
)
return res
else:
if self.denominator >= 0:
res = self.numerator <= self.denominator * other
else:
res = self.numerator >= self.denominator * other
return res

def __eq__(self, other: Union["Fraction", int]) -> bool:
"""returns self == other"""
if isinstance(other, Fraction):
return (
self.numerator * other.denominator == self.denominator * other.numerator
)
else:
return self.numerator == self.denominator * other

def __lt__(self, other: Union["Fraction", int]) -> bool:
"""returns self < other"""
if isinstance(other, Fraction):
if self.denominator * other.denominator >= 0:
res = (
self.numerator * other.denominator
< self.denominator * other.numerator
)
else:
res = (
self.numerator * other.denominator
> self.denominator * other.numerator
)
return res
else:
if self.denominator >= 0:
res = self.numerator < self.denominator * other
else:
res = self.numerator > self.denominator * other
return res

def __gt__(self, other: Union["Fraction", int]) -> bool:
"""returns self > other"""
if isinstance(other, Fraction):
if self.denominator * other.denominator >= 0:
res = (
self.numerator * other.denominator
> self.denominator * other.numerator
)
else:
res = (
self.numerator * other.denominator
< self.denominator * other.numerator
)
return res
else:
if self.denominator >= 0:
res = self.numerator > self.denominator * other
else:
res = self.numerator < self.denominator * other
return res

def __floordiv__(self, other: Union["Fraction", int]) -> int:
if isinstance(other, Fraction):
x = self / other
return x.numerator // x.denominator
else:
return self.numerator // (other * self.denominator)


def _norm_signs_fraction(a: Fraction) -> Fraction:
Expand All @@ -62,50 +194,5 @@ def norm_fraction(a: Fraction) -> Fraction:
return _norm_gcd_fraction(_norm_signs_fraction(a))


def ge_fraction(a: Fraction, b: Fraction) -> bool:
"""returns a >= b"""
if a.denominator * b.denominator >= 0:
res = a.numerator * b.denominator >= a.denominator * b.numerator
else:
res = a.numerator * b.denominator <= a.denominator * b.numerator
return res


def le_fraction(a: Fraction, b: Fraction) -> bool:
"""returns a <= b"""
if a.denominator * b.denominator >= 0:
res = a.numerator * b.denominator <= a.denominator * b.numerator
else:
res = a.numerator * b.denominator >= a.denominator * b.numerator
return res


def eq_fraction(a: Fraction, b: Fraction) -> bool:
"""returns a == b"""
return a.numerator * b.denominator == a.denominator * b.numerator


def lt_fraction(a: Fraction, b: Fraction) -> bool:
"""returns a < b"""
if a.denominator * b.denominator >= 0:
res = a.numerator * b.denominator < a.denominator * b.numerator
else:
res = a.numerator * b.denominator > a.denominator * b.numerator
return res


def gt_fraction(a: Fraction, b: Fraction) -> bool:
"""returns a > b"""
if a.denominator * b.denominator >= 0:
res = a.numerator * b.denominator > a.denominator * b.numerator
else:
res = a.numerator * b.denominator < a.denominator * b.numerator
return res


def floor_fraction(a: Fraction) -> int:
return a.numerator // a.denominator


def ceil_fraction(a: Fraction) -> int:
return (a.numerator + a.denominator - sign(a.denominator)) // a.denominator
Loading

0 comments on commit 26f7532

Please sign in to comment.