1048 lines
33 KiB
Python
1048 lines
33 KiB
Python
import ast
|
|
import os
|
|
import re
|
|
import sys
|
|
import json
|
|
import hashlib
|
|
import operator
|
|
import functools
|
|
import itertools
|
|
|
|
import ifcopenshell.express
|
|
import ifcopenshell.express.express_parser
|
|
|
|
import networkx as nx
|
|
|
|
from codegen import indent
|
|
|
|
DEBUG = False
|
|
|
|
|
|
def to_graph(tree):
    """Convert a nested dict/list parse tree into a simplified ``nx.DiGraph``.

    Each dict key and each list position becomes a node whose id is its
    slash-separated path (``root/<key>/...``, list positions zero-padded to
    three digits, e.g. ``root/expression/000``) and whose ``label`` is the
    dict key, or ``None`` for list positions. A scalar leaf becomes an extra
    ``<parent>_value`` node labelled with the scalar itself.

    The graph is then simplified:

    * unlabelled nodes that have a predecessor are spliced out, with each
      predecessor linked directly to each successor. The exception is a node
      with a single labelled predecessor and several successors carrying
      distinct, non-None labels: it takes over the predecessor's label, and
      the predecessor is spliced out afterwards instead;
    * childless nodes whose label is not an EXPRESS grammar rule name
      (``express_parser.all_rules``) are flagged ``is_terminal=True``.
    """
    graph = nx.DiGraph()

    def add_subtree(value, parent=None):
        # Recursively mirror *value* below the node id *parent*.
        if isinstance(value, dict):
            children = value.items()
        elif isinstance(value, list):
            children = [(None, item) for item in value]
        else:
            assert parent
            leaf = parent + "_value"
            graph.add_edge(parent, leaf)
            return graph.add_node(leaf, label=value)

        for position, (key, child) in enumerate(children):
            segment = key or format(position, "03d")
            node_id = f"{parent or 'root'}/{segment}"
            graph.add_node(node_id, label=key)
            if parent:
                graph.add_edge(parent, node_id)
            add_subtree(child, node_id)

    def splice_out(node):
        # Connect every predecessor of *node* to every successor, then drop it.
        sources = list(graph.predecessors(node))
        targets = list(graph.successors(node))
        for src, dst in itertools.product(sources, targets):
            graph.add_edge(src, dst)
        graph.remove_node(node)

    add_subtree(tree)

    # Remove intermediate anonymous nodes. These are often the result of
    # ZeroOrMore() productions in bootstrap.py, which produce an intermediate
    # list-index node in to_tree(). The root-level nodes (no predecessors) are
    # never candidates.
    anonymous = [
        node
        for node in graph.nodes
        if graph.nodes[node].get("label") is None and graph.in_degree(node) > 0
    ]

    # Predecessors that lend their label to a grouping node are removed only
    # after the whole pass.
    deferred = set()

    for node in anonymous:
        preds = list(graph.predecessors(node))
        if len(preds) == 1 and graph.nodes[preds[0]].get("label"):
            succs = list(graph.successors(node))
            labels = [graph.nodes[s].get("label") for s in succs]
            if len(succs) > 1 and len(set(labels)) > 1 and None not in labels:
                # Grouping node with heterogeneous content: keep it, give it
                # the parent's label, and drop the parent later instead.
                graph.nodes[node]["label"] = graph.nodes[preds[0]].get("label")
                deferred.add(preds[0])
                continue
        splice_out(node)

    for node in deferred:
        splice_out(node)

    rule_names = ifcopenshell.express.express_parser.all_rules
    for node, attrs in graph.nodes(data=True):
        if graph.out_degree(node) == 0 and attrs.get("label") not in rule_names:
            attrs["is_terminal"] = True

    return graph
|
|
|
|
|
|
def write_dot(fn, g):
    """Write graph *g* to the file *fn* in Graphviz DOT format (debug aid).

    Node ids are hashed into DOT-safe identifiers. A node is labelled with its
    ``label`` attribute, or with its id when the label is missing or falsy.
    Terminal nodes (``is_terminal``) are drawn as boxes with their label shown
    in double quotes. Other nodes are drawn without a shape.

    Attribute values that start with ``<`` are emitted unquoted, so Graphviz
    parses them as HTML-like labels. Every other value is emitted as a quoted
    string with its embedded double quotes escaped. Previously these quotes
    were not escaped, which produced invalid DOT.

    :param fn: output path; the file is written as UTF-8.
    :param g: a graph exposing ``g.nodes`` (iterable of ids, indexable to an
        attribute dict) and ``g.edges`` (iterable of ``(a, b)`` pairs).
    """

    def node_name(n):
        # Node ids contain "/" and other characters that are not valid in bare
        # DOT identifiers; a hex digest always is.
        return "N" + hashlib.md5(n.encode()).hexdigest()

    def dot_value(v):
        v = str(v)
        if v.startswith("<"):
            return v
        return '"' + v.replace('"', '\\"') + '"'

    def attr_list(attrs):
        inner = ",".join(f"{k}={dot_value(v)}" for k, v in attrs.items())
        return f"[{inner}]" if inner else ""

    with open(fn, "w", encoding="utf-8") as f:

        def w(*args, **kwargs):
            print(*args, file=f, **kwargs)

        w("digraph", "{")

        for n in g.nodes:
            data = g.nodes[n]
            label = data.get("label") or n
            if data.get("is_terminal"):
                # Show terminals with literal quotes around their text.
                attrs = {"label": f'"{label}"', "shape": "rect"}
            else:
                attrs = {"label": label, "shape": "none"}
            w(node_name(n), attr_list(attrs), ";", sep="")

        for a, b in g.edges:
            w(node_name(a), "->", node_name(b), ";")

        w("}")
|
|
|
|
|
|
from pyparsing import *
|
|
|
|
# Grammar for codegen_rule patterns: identifiers joined by "/", for example
# "qualifiable_factor/attribute_ref". The "/" separators are suppressed, so
# parsing a pattern yields only its identifier tokens.
SLASH = Suppress("/")
identifier = Word(alphanums + "_")
rule = identifier + ZeroOrMore(SLASH + identifier)
|
|
|
|
|
|
def paths(G, root, length):
    """Yield the label tuples of all downward paths of *length* nodes from *root*.

    A path starts at *root* and follows successor edges. With ``length == 1``
    the only result is ``(label_of_root,)``. Successors come from
    ``nx.bfs_successors`` limited to ``length - 1`` levels, so in a non-tree
    graph each node is reached through its BFS-tree parent only.
    """

    def label_of(node):
        return G.nodes[node].get("label")

    if length == 1:
        yield (label_of(root),)
        return

    children = dict(nx.bfs_successors(G, root, depth_limit=length - 1))

    def walk(trail):
        # *trail* is the list of node ids from root to the current node.
        if len(trail) == length:
            yield tuple(label_of(node) for node in trail)
            return
        for child in children.get(trail[-1], []):
            yield from walk(trail + [child])

    yield from walk([root])
|
|
|
|
|
|
class context:
    """Navigable view over one or more nodes ("rules") of a parse graph.

    ``graph`` is a graph such as the one built by ``to_graph`` and ``rules``
    is a list of node ids in it. Reading an attribute that the class does not
    define, e.g. ``ctx.rule_head``, selects the *direct* successors of the
    current nodes whose ``label`` equals that name and wraps them in a new
    ``context``. Grammar productions can therefore be chained
    (``ctx.rule_head.rule_id``).

    A context is truthy when it holds at least one node (via ``__len__``).
    ``str()`` returns the first string value or terminal label found at or
    below its single node.
    """

    def __init__(self, graph, rules):
        self.graph = graph
        self.rules = rules

    def __getattr__(self, k):
        # Only reached for names not found by normal lookup. Dunder names are
        # Python protocol probes (pickle's __setstate__, etc.), never grammar
        # labels. They must raise, otherwise e.g. unpickling recurses forever
        # through self.graph / self.rules.
        if k.startswith("__") and k.endswith("__"):
            raise AttributeError(k)

        matches = [
            n
            for r in self.rules
            for n in self.graph.successors(r)
            if self.graph.nodes[n].get("label") == k
        ]
        return context(self.graph, matches)

    def has_inverse(self, a):
        """Return True when any current node has a predecessor labelled *a*."""
        for r in self.rules:
            labels = (self.graph.nodes[n].get("label") for n in self.graph.predecessors(r))
            if a in labels:
                return True
        return False

    def __iter__(self):
        # One single-node context per current node.
        for r in self.rules:
            yield context(self.graph, [r])

    def descendants(self):
        """Return the ids of all child nodes with the first node's id prefix removed.

        For example, a child ``root/x/expression`` of ``root/x`` yields
        ``"expression"``.
        """
        return [b.rules[0][len(self.rules[0]) + 1 :] for b in self.branches(allow_multiple=True)]

    def __repr__(self):
        try:
            s = "\n\n" + str(self)
        except Exception:
            # str() only works for single-node contexts with a string value below them.
            s = ""
        return f"<rule_context ({' '.join(self.descendants())})>{s}"

    def __str__(self):
        assert len(self.rules) == 1
        # The node itself first, then its descendants in breadth-first order.
        nodes = itertools.chain(
            self.rules,
            itertools.chain.from_iterable(dict(nx.bfs_successors(self.graph, self.rules[0])).values()),
        )
        data = self.graph.nodes
        # Nodes that are terminals or that already carry a generated value.
        candidates = [n for n in nodes if data[n].get("is_terminal") or data[n].get("value")]
        attrs = [data[n].get("value", data[n]["label"]) for n in candidates]
        # A leading `empty` value means "this node generates no code".
        if attrs and type(attrs[0]) is empty:
            return ""
        # Non-string values (e.g. context objects) are skipped.
        return [a for a in attrs if isinstance(a, str)][0]

    def __eq__(self, other):
        if not isinstance(other, context):
            return NotImplemented
        return self.graph == other.graph and self.rules == other.rules

    def __hash__(self):
        # ``rules`` is a list (unhashable); hash an immutable snapshot so that
        # contexts comparing equal also hash equal.
        return hash(tuple(self.rules))

    def branches(self, allow_multiple=False, exclude=()):
        """Return one context per child of the current node(s), minus *exclude*.

        Children of each node are ordered by node id. Unless *allow_multiple*
        is true, the context must hold exactly one node.
        """
        if not allow_multiple:
            assert len(self.rules) == 1
        combined = []
        for R in self.rules:
            combined.extend(
                sorted(
                    (context(self.graph, [n]) for n in self.graph.successors(R)),
                    key=lambda c: c.rules[0] if c.rules else "",
                )
            )
        return [c for c in combined if c not in exclude]

    def parent(self):
        """Return a context over the predecessors of the single current node."""
        assert len(self.rules) == 1
        return context(self.graph, list(self.graph.predecessors(self.rules[0])))

    def branch(self, i):
        return self.branches()[i]

    def __len__(self):
        return len(self.rules)

    def __getitem__(self, k):
        return list(self)[k]

    def key(self):
        """Return the grammar label encoded in the node id(s).

        This is the last ``/``-separated segment that is not a purely numeric
        list index. All current nodes must share the same path once their
        numeric segments are normalised.
        """
        parts = [tuple("n" if p.isdigit() else p for p in s.split("/")) for s in self.rules]
        assert len(set(parts)) == 1
        return [x for x in parts[0][::-1] if x != "n"][0]


# @todo
context_class = context
|
|
|
|
|
|
class codegen_rule:
    """A code generation rule: a label-path pattern plus a generator function.

    ``pattern`` is a ``/``-separated label path, such as
    ``"qualifiable_factor/attribute_ref"``. A node matches when one of its
    downward paths of that many nodes carries exactly those labels (see
    ``paths``). ``fn`` receives a ``context`` for the matching node. It returns
    the generated code (or a marker such as ``empty()``), which is stored on
    the node as ``value``. Every instance registers itself in ``all_rules``.
    """

    # Registry of all constructed rules, in construction order. Defined on
    # the class up front so apply() also works before any rule is registered.
    all_rules = []

    def __init__(self, pattern, fn):
        self.pattern = tuple(rule.parseString(pattern))
        self.fn = fn
        codegen_rule.all_rules.append(self)

    def __call__(self, graph, node):
        """Run the rule on *node*, store the result as its ``value``, and return it."""
        v = self.fn(context(graph, [node]))
        graph.nodes[node]["value"] = v
        return v

    @staticmethod
    def apply(G):
        """Evaluate all registered rules over *G*, children before parents.

        Nodes are visited in reverse topological order, so a child's ``value``
        is already set when its parent's rule runs. When several rules match
        a node they all run in registration order, and the last value is
        kept. Returns the value produced by the last rule that fired, or None.
        """
        order = reversed(list(nx.topological_sort(G)))
        v = None
        for n in order:
            for r in codegen_rule.all_rules:
                if r.pattern in paths(G, n, len(r.pattern)):
                    v = r(G, n)
        return v
|
|
|
|
|
|
def process_rule_decl(context):
    """Generate a Python class for an EXPRESS global RULE declaration.

    The class is named after ``rule_head.rule_id``, has ``SCOPE = "file"`` and
    a static ``__call__(file)``. That method first binds
    ``rule_head.entity_ref`` to ``file.by_type(<entity_ref>)`` and then emits
    the local declarations, the rule statements (if any) and the WHERE-clause
    domain rules, each passed through ``indent(8, ...)`` so they sit inside
    ``__call__``.
    """
    # NOTE(review): str() of a context asserts a single node, so this presumably
    # assumes exactly one entity in the RULE's FOR list. TODO confirm.
    return f"""
class {context.rule_head.rule_id}:
    SCOPE = "file"

    @staticmethod
    def __call__(file):
        {context.rule_head.entity_ref} = file.by_type("{context.rule_head.entity_ref}")
{indent(8, context.algorithm_head.local_decl)}
{indent(8, context.stmt.branches()) if context.stmt else ''}
{indent(8, context.where_clause.domain_rule)}
"""
|
|
|
|
|
|
class empty:
    """Marker value for codegen rules that produce no code.

    It is returned, for example, for LOCAL declarations without an
    initialiser and for group qualifiers. ``context.__str__`` renders a node
    whose first collected value is an ``empty`` instance as the empty string.
    """


# Regex word boundary. It is kept in a variable because the nested f-strings
# that build attribute-matching patterns cannot contain a backslash inside a
# replacement field (before Python 3.12).
wb = r"\b"
|
|
|
|
|
|
def process_type_decl(scope, context):
    """Generate Python for the WHERE rules, and for entities the DERIVE attributes, of a declaration.

    ``scope`` is ``"type"`` for a TYPE declaration or ``"entity"`` for an
    ENTITY declaration; it is bound with ``functools.partial`` at rule
    registration. Each WHERE rule becomes a class
    ``<Name>_<RuleLabel>`` with ``SCOPE``/``TYPE_NAME``/``RULE_NAME``
    attributes and a static ``__call__(self)``. That method binds the
    referenced attributes to lower-case locals and then runs the rule body.
    For entities, each derived attribute also becomes a module-level
    ``calc_<Entity>_<attribute>(self)`` function. The snippets are joined
    with blank lines.

    Relies on the module-global ``schema`` assigned in the ``__main__`` block.
    """
    # Base name of the generated rule classes / derived-attribute functions.
    class_name = context.type_id if scope == "type" else context.entity_head.entity_id

    # Attribute names (inherited first) that generated code may bind as locals.
    attributes = []

    if scope == "entity":

        def get_attributes(nm):
            # Recurse up the chain of *first* supertypes so that inherited
            # names are yielded before the entity's own.
            ent = schema.entities[nm]
            if ent.supertypes:
                yield from get_attributes(ent.supertypes[0])
            yield from [a.name for a in ent.attributes]
            yield from [a.name for a in ent.inverse]
            # redeclared do not need to be printed, because they're emitted
            # as part of supertype
            yield from [a[0] for a in ent.derive if isinstance(a[0], str)]

        # @todo derived and inverse attributes
        attributes = list(get_attributes(class_name))

    def format_rule(domain_rule):
        # Only attributes whose lower-cased name occurs as a whole word in the
        # rule text are bound as locals ("x = self.X") before the rule body.
        return f"""
class {class_name}_{domain_rule.rule_label_id}:
    SCOPE = "{scope}"
    TYPE_NAME = "{class_name}"
    RULE_NAME = "{domain_rule.rule_label_id}"

    @staticmethod
    def __call__(self):
{indent(8, (f"{a.lower()} = self.{a}" for a in attributes if re.search(f'{wb}{a.lower()}{wb}', str(domain_rule))))}
{indent(8, domain_rule)}
"""

    # For a TYPE the WHERE clause hangs off the declaration itself; for an
    # ENTITY it hangs off entity_body.
    rule_parent = context if scope == "type" else context.entity_body

    statements = []

    if rule_parent.where_clause:
        # @todo should we not try to maintain a 1-1 correspondence?
        statements.extend(map(format_rule, rule_parent.where_clause.branches()))

    if scope == "entity":

        def format_derived(derived_attr):
            # Backslashes are not allowed inside f-string replacement fields
            # (before Python 3.12), so the line-continuation backslash that
            # joins "return" with the expression line below comes from a
            # variable. For a redeclared attribute the function is named after
            # its qualifier, and [1:] strips the leading "." that the
            # attribute_qualifier rule prepends.
            slash = "\\"
            return f"""
def calc_{class_name}_{str(derived_attr.attribute_decl.redeclared_attribute.qualified_attribute.attribute_qualifier)[1:] if derived_attr.attribute_decl.redeclared_attribute else derived_attr.attribute_decl}(self):
{indent(4, (f"{a.lower()} = self.{a}" for a in attributes if re.search(f'{wb}{a.lower()}{wb}', str(derived_attr.expression))))}
{indent(4, f"return {slash}")}
{indent(4, derived_attr.expression)}
"""

        if context.entity_body.derive_clause:
            statements.extend(map(format_derived, context.entity_body.derive_clause.branches()))

    return "\n\n".join(statements)
|
|
|
|
|
|
def process_domain_rule(context):
    """Render an EXPRESS domain (WHERE) rule as a Python ``assert``.

    The check is ``is not False`` rather than plain truthiness. Any result
    other than ``False`` therefore lets the rule pass, including the
    ``INDETERMINATE`` sentinel and ``None`` defined by the generated
    runtime preamble.
    """
    return f"""
assert ({context.expression}) is not False
"""
|
|
|
|
|
|
def wrap_parens(s):
    """Return ``str(s)``, wrapped in parentheses when it contains a space.

    Keeps compound sub-expressions grouped when they are spliced into a
    larger generated expression.
    """
    text = str(s)
    return f"({text})" if " " in text else text
|
|
|
|
|
|
def process_expression(context):
    """Render expression-like productions (expression, simple_expression, logical_expression, term).

    Dispatches on which operator child is present:

    * ``rel_op_extended`` — a relational operator between operands, with an
      ``IN``-against-a-list-of-string-literals special case;
    * ``multiplication_like_op`` — ``||`` (complex entity instance
      construction) is merged into one constructor call; other operators are
      interleaved with the ``factor`` operands;
    * ``add_like_op`` — interleaved with the ``term`` (or mixed) operands.

    Returns None when none of these children exist, leaving the node
    without generated code of its own.
    """

    def concat(a, b, **kwargs):
        # Interleave operands (b's children) with operators (a's children):
        # b0 a0 b1 a1 b2 ... Operands that contain a space are parenthesised.
        # The leading None pairs with b0 and is sliced off by [1:].
        return " ".join(
            map(
                str,
                sum(
                    zip(
                        [None] + a.branches(**kwargs),
                        map(wrap_parens, b.branches(**kwargs)),
                    ),
                    (),
                )[1:],
            )
        )

    if context.rel_op_extended:
        if context.term:
            # IfcSameValue
            return concat(
                context.rel_op_extended,
                context,
                allow_multiple=True,
                exclude=[context.rel_op_extended],
            )
        else:
            if len(context.simple_expression.branches()) == 2 and str(context.rel_op_extended) == "in":
                # IfcBlobTexture
                # When the right operand is a literal list of strings, the left
                # operand is lower-cased before the membership test.
                try:
                    is_literal_str_list = set(
                        map(type, ast.literal_eval(str(context.simple_expression.branches()[1])))
                    ) == {str}
                except:
                    # Not a Python literal (or not iterable): no special case.
                    is_literal_str_list = False
                if is_literal_str_list:
                    a, b = map(str, context.simple_expression.branches())
                    return f"{a}.lower() {str(context.rel_op_extended)} {b}"
            return concat(context.rel_op_extended, context.simple_expression)
    elif context.multiplication_like_op:
        if str(context.multiplication_like_op.branches()[0]) == "||":
            # Complex entity instance construction "A(..) || B(..) || ...".
            # The positional arguments of every partial constructor are mapped
            # to attribute names and merged into keyword arguments for the most
            # derived type (the longest supertype chain).
            all_args = {}
            most_concrete_type = None
            most_concrete_type_inheritance_chain_length = -1

            for s in context.factor.branches():
                # Split "TypeName(args)" and drop the closing parenthesis.
                typename, args = str(s).split("(", 1)
                args = args[:-1]

                # Record [start, end) offsets of the top-level comma-separated
                # arguments, ignoring commas nested in brackets or parentheses.
                break_points = [[0]]
                bracket_nesting = 0
                for i, tk in enumerate(args):
                    if tk in "[(":
                        bracket_nesting += 1
                    if tk in ")]":
                        bracket_nesting -= 1
                    if tk == "," and bracket_nesting == 0:
                        break_points[-1].append(i)
                        break_points.append([i + 1])

                break_points[-1].append(len(args))

                # @todo don't depend on registered schema
                S = ifcopenshell.ifcopenshell_wrapper.schema_by_name(schema.name)
                entity = S.declaration_by_name(typename)
                entity_attributes = entity.attributes()

                def count_chain_length(ent):
                    # Number of declarations from ent up to its root supertype.
                    length = 0
                    while ent:
                        ent = ent.supertype()
                        length += 1
                    return length

                args = [args[slice(*x)] for x in break_points]

                # Empty argument strings are skipped.
                for i, arg in filter(lambda p: p[1], enumerate(args)):
                    all_args[entity_attributes[i].name()] = arg

                cl = count_chain_length(entity)
                if cl > most_concrete_type_inheritance_chain_length:
                    most_concrete_type = entity.name()
                    most_concrete_type_inheritance_chain_length = cl

            return f"{most_concrete_type}({', '.join(f'{a[0]}={a[1]}' for a in all_args.items())})"
        else:
            return concat(context.multiplication_like_op, context.factor)
    elif context.add_like_op:
        if context.factor or len(context.term) > 1:
            # @todo not sure why this is required (in IfcCrossProduct)
            # @todo not sure what's going on here, why we have both factor and term as direct child productions of simple_expression (in IfcDotProduct)
            return concat(
                context.add_like_op,
                context,
                allow_multiple=True,
                exclude=[context.add_like_op],
            )
        else:
            return concat(context.add_like_op, context.term)
|
|
|
|
|
|
def process_interval(context):
    """Render an EXPRESS interval ``{low op item op high}`` as ``low op item op high``.

    The output is a Python chained comparison, assuming the two
    ``interval_op`` children render as Python comparison operators.
    """
    lower_op, upper_op = context.interval_op.branches()
    parts = (
        context.interval_low,
        lower_op,
        context.interval_item,
        upper_op,
        context.interval_high,
    )
    return " ".join(str(part) for part in parts)
|
|
|
|
|
|
def simple_concat(context):
    """Concatenate the generated text of a node's children without separators.

    Registered for:

    * ``simple_factor`` — only to join a unary operator (``-``) with its
      operand, e.g. a number literal;
    * ``primary`` — only to join an index/attribute qualifier with the
      qualifiable operand;
    * ``qualifier``.

    Children are stably re-ordered so that unary operators come first and
    qualifiers (text starting with ``.`` or ``[``) come last. A ``not`` with a
    single operand renders as ``not <operand>``, with the operand
    parenthesised when it contains a space.
    """

    def rank(token):
        # @todo this is a really ugly hack, can we not depend on stable branch order and why?
        if token in ("-", "+", "not"):
            return -1  # unary operators
        if token[:1] in (".", "["):
            return 1  # qualifiers
        return 0  # everything else

    # Re-ordering is still needed even though branches() orders children by node id.
    pieces = sorted((str(child) for child in context.branches()), key=rank)

    if len(pieces) == 2 and pieces[0] == "not":
        return f"not {wrap_parens(pieces[1])}"
    return "".join(pieces)
|
|
|
|
|
|
def process_rel_op(context):
    """Map EXPRESS (in)equality operators to Python; other operators yield None.

    Both the value forms (``=``/``<>``) and the instance forms
    (``:=:``/``:<>:``) map to ``==``/``!=``.
    """
    # @todo the distinction between value comparison and instance comparison
    op = str(context)
    if op in ("<>", ":<>:"):
        return "!="
    if op in ("=", ":=:"):
        return "=="
    return None
|
|
|
|
|
|
def process_if_stmt(context):
    """Render an EXPRESS IF statement (with optional ELSE) as Python.

    The condition is the ``logical_expression`` child when it exists and has
    children; otherwise it is the ``expression`` child. The existence check
    matters because ``branches()`` asserts a single node, so calling it on a
    missing ``logical_expression`` would fail instead of falling back. The
    THEN and ELSE statement lists are passed through ``indent(4, ...)``.
    """
    logical = context.logical_expression
    condition = logical if logical and logical.branches() else context.expression
    s = f"if {condition}:\n{indent(4, context.stmt.branches())}"
    if context.else_stmt:
        s += f"\nelse:\n{indent(4, context.else_stmt.branches())}"
    return s
|
|
|
|
|
|
def process_repeat_stmt(context):
    """Render ``REPEAT var := lo TO hi; ... END_REPEAT`` as a Python ``for`` loop.

    The upper bound is inclusive in the generated loop, hence ``hi + 1`` as
    the exclusive ``range`` end.
    """
    # NOTE(review): only variable_id, bound_1 and bound_2 are read; a BY step
    # or WHILE/UNTIL control would be silently ignored. TODO confirm none occur.
    control = context.repeat_control.increment_control
    header = f"for {control.variable_id} in range({control.bound_1}, {control.bound_2} + 1):"
    body = indent(4, context.stmt.branches())
    return f"{header}\n{body}"
|
|
|
|
|
|
def process_function_decl(context):
    """Render an EXPRESS FUNCTION as a Python ``def``.

    Parameter names are lower-cased. The body consists of the local
    declarations followed by the statements, each passed through
    ``indent(4, ...)``.
    """
    head = context.function_head
    parameter_names = [
        str(param).lower() for param in head.formal_parameter.parameter_id.branches(allow_multiple=True)
    ]
    signature = f"def {head.function_id}({', '.join(parameter_names)}):"
    local_decls = indent(4, context.algorithm_head.local_decl)
    statements = indent(4, context.stmt.branches())
    return f"{signature}\n{local_decls}\n{statements}"
|
|
|
|
|
|
def process_query(context):
    """Render an EXPRESS QUERY as a Python list comprehension.

    The loop variable is lower-cased. The filter is the
    ``logical_expression`` child when it exists and has children; otherwise
    it is the ``expression`` child.
    """
    var = str(context.variable_id).lower()
    logical = context.logical_expression
    condition = logical if logical and logical.branches() else context.expression
    return f"[{var} for {var} in {context.aggregate_source} if {condition}]"
|
|
|
|
|
|
def process_local_variable(context):
    """Render one EXPRESS LOCAL declaration as a Python assignment.

    A declaration without an initialiser produces ``empty()``, i.e. no code.
    For SET-typed locals, aggregate literals in the initialiser are wrapped in
    ``express_set(...)`` so they get the set operators of the generated
    runtime. A ``[`` that directly follows an identifier character, ``)`` or
    ``]`` is a subscript such as ``x[i - EXPRESS_ONE_BASED_INDEXING]``, not a
    literal. Such subscripts are left untouched; previously they were
    rewritten into invalid code like ``xexpress_set([...])``.
    """
    if not context.expression:
        return empty()

    expr = str(context.expression)
    if context.parameter_type.generalized_types.general_aggregation_types.general_set_type:
        expr = re.sub(r"(?<![\w)\]])(\[[^\]]*\])", r"express_set(\1)", expr)

    return "%s = %s" % (str(context.variable_id).lower(), expr)
|
|
|
|
|
|
def process_function_call(context):
    """Render a call to a built-in or user function.

    The function name is the ``built_in_function`` child when present,
    otherwise ``function_ref``. The argument list is empty when
    ``actual_parameter_list`` is missing or has no children.
    """
    name = f"{context.built_in_function or context.function_ref}"
    params = context.actual_parameter_list
    args = f"{params if params and params.branches() else ''}"
    # exists() receives a callable when its argument indexes an aggregate, so
    # the runtime helper can catch IndexError: express semantics dictate that
    # an out-of-bounds index returns unknown (IfcTypeObject_WR1).
    prefix = "lambda: " if name == "exists" and "[" in args else ""
    return f"{name}({prefix}{args})"
|
|
|
|
|
|
def make_lowercase(context):
    """Return the generated text of *context*, lower-cased.

    Used to normalise identifier references to the lower-case names that the
    generator emits for parameters and local variables.
    """
    text = str(context)
    return text.lower()


def make_lowercase_if(fn):
    """Return a rule function that lower-cases only when ``fn(context)`` is truthy.

    When the predicate fails, the returned function yields None.
    """

    def inner(context):
        return make_lowercase(context) if fn(context) else None

    return inner
|
|
|
|
|
|
def process_assignment(context):
    """Render an EXPRESS assignment statement.

    An assignment to a single indexed element, ``aggr[index] := value``, is
    emitted as a copy-modify-rebind through a ``temp`` list, presumably
    because the aggregate may be an immutable sequence (TODO confirm).
    Every other target is assigned directly.
    """
    target = str(context.general_ref)
    if context.qualifier:
        target += str(context.qualifier)

    # @todo ugly regex hack
    indexed = re.match(r"^([^\[]+)\[([^\[]+)\]$", target)
    if indexed is None:
        return "%s = %s" % (target, context.expression)

    aggregate, index = indexed.groups()
    return f"temp = list({aggregate})\ntemp[{index}] = {context.expression}\n{aggregate} = temp"
|
|
|
|
|
|
def process_case_action(context):
    """Render one EXPRESS CASE action as an ``if``/``elif`` clause.

    The keyword depends on the action's index among its parent's branches:
    index 0 gives ``if``, anything else gives ``elif``. The condition compares
    the parent's selector ``expression`` with this action's label.
    """
    # NOTE(review): the index counts *all* children of the case statement, so
    # this assumes the case_action nodes sort before the selector/other
    # children by node id. TODO confirm.
    case_stmt = context.parent()
    position = case_stmt.branches().index(context)
    keyword = "elif" if position else "if"

    label = str(context.expression)
    # @todo this is yet again an ugly hack
    # A lower-case quoted string label causes the selector to be lower-cased
    # before the comparison.
    lower = ".lower()" if re.match(r"^'[a-z0-9]+'$", label) else ""

    return f"{keyword} {case_stmt.expression}{lower} == {label}:\n{indent(4, context.stmt.branches())}"
|
|
|
|
|
|
def process_case_statement(context):
    """Render an EXPRESS CASE statement from its already-generated actions.

    Only children whose relative id starts with ``case_action`` are emitted;
    each is already rendered as an ``if``/``elif`` clause. A direct ``stmt``
    child with children (presumably the OTHERWISE branch) becomes a trailing
    ``else:`` block.
    """
    non_actions = [getattr(context, name) for name in context.descendants() if not name.startswith("case_action")]
    clauses = [str(clause) for clause in context.branches(exclude=non_actions)]

    otherwise = context.stmt
    if otherwise and otherwise.branches():
        clauses.append(f"else:\n{indent(4, otherwise)}")

    return "\n".join(clauses)
|
|
|
|
|
|
def process_aggregate_initializer(context):
    """Render an EXPRESS aggregate initializer ``[...]`` as a Python list.

    An element with a ``repetition`` count becomes list repetition,
    ``([value] * count)``. Otherwise the children of ``element`` are joined
    with commas; a missing ``element`` gives ``[]``.
    """
    element = context.element
    if element.repetition:
        return "([%s] * %s)" % (element.expression, element.repetition)

    items = element.branches() if element else ()
    return "[%s]" % ",".join(str(item) for item in items)
|
|
|
|
|
|
def process_index(context):
    """Render an index as a zero-based Python subscript.

    EXPRESS indices start at 1, so the index is shifted by
    ``EXPRESS_ONE_BASED_INDEXING``, a constant defined in the generated
    runtime. Directly under an ``index_qualifier``, the context object itself
    is returned unchanged. It is not a string, so ``context.__str__``
    ignores it and the enclosing qualifier renders the subscript.
    """
    if context.parent().key() != "index_qualifier":
        return "[%s - EXPRESS_ONE_BASED_INDEXING]" % context
    return context
|
|
|
|
|
|
# Rule registrations. Order matters: codegen_rule.apply() tries the rules in
# this order for every node, and a later matching rule overwrites the value
# stored by an earlier one.
#
# sizeof() is implemented as a function in the generated code, so no rule
# is needed for it. Former rule, kept for reference:
#   codegen_rule("built_in_function/SIZEOF", lambda context: f"len")
# @todo
for _pattern, _handler in (
    ("function_call", process_function_call),
    (
        "actual_parameter_list",
        lambda context: ",".join(map(str, context.expression.branches() if context.expression else [])),
    ),
    ("entity_decl", functools.partial(process_type_decl, "entity")),
    ("rule_decl", process_rule_decl),
    ("type_decl", functools.partial(process_type_decl, "type")),
    ("function_decl", process_function_decl),
    ("domain_rule", process_domain_rule),
    ("expression", process_expression),
    ("simple_expression", process_expression),
    ("logical_expression", process_expression),
    ("term", process_expression),
    ("query_expression", process_query),
    ("aggregate_initializer", process_aggregate_initializer),
    ("interval", process_interval),
    ("simple_factor", simple_concat),
    ("primary", simple_concat),
    ("qualifier", simple_concat),
    ("return_stmt", lambda context: "return %s" % context),
    ("compound_stmt", lambda context: "\n".join(map(str, context.stmt.branches()))),
    ("if_stmt", process_if_stmt),
    ("repeat_stmt", process_repeat_stmt),
    # Former alternative: ("index", lambda context: '**express_index(%s)' % context)
    ("index", process_index),
    ("index_qualifier", process_index),
    ("group_qualifier", lambda context: empty()),
    ("attribute_qualifier", lambda context: ".%s" % context),
    ("rel_op", process_rel_op),
    ("built_in_constant", lambda context: "None" if str(context) == "?" else str(context)),
    ("assignment_stmt", process_assignment),
    ("local_variable", process_local_variable),
    ("local_decl", lambda context: "\n".join(map(str, context.branches()))),
    ("general_ref/parameter_ref", make_lowercase),
    (
        "qualifiable_factor/attribute_ref",
        make_lowercase_if(lambda context: str(context) not in set(map(str, schema.all_declarations.keys()))),
    ),
    ("case_action", process_case_action),
    ("case_stmt", process_case_statement),
    ("escape_stmt", lambda context: "break"),
    ("XOR", lambda context: "^"),
    ("MOD", lambda context: "%"),
    ("TRUE", lambda context: "True"),
    ("FALSE", lambda context: "False"),
):
    codegen_rule(_pattern, _handler)
del _pattern, _handler
|
|
|
|
|
|
class AttributeGetattrTransformer(ast.NodeTransformer):
    """Rewrite attribute/subscript reads in generated code into runtime helper calls.

    ``x.attr`` becomes ``express_getattr(x, 'attr', INDETERMINATE)`` and
    ``x[i]`` becomes ``express_getitem(x, i, INDETERMINATE)``. Nodes are left
    untouched when they are:

    * inside one of the runtime helper functions (``_RUNTIME_HELPERS``);
    * assignment or deletion targets;
    * ``create_entity`` or dunder attribute lookups;
    * attribute reads in module-level statements.

    Subscript slices are not visited. ``assign_parent_refs`` must be called
    on the tree before ``visit``, because these decisions inspect each
    node's ancestors through ``.parent``.
    """

    # Helpers emitted in the generated module's preamble. Their bodies keep
    # plain attribute/subscript access; express_getitem, for instance, does
    # ``aggr[idx]`` itself and would otherwise be rewritten into a call to itself.
    _RUNTIME_HELPERS = ("is_entity", "usedin", "express_len", "express_getitem", "typeof", "express_getattr")

    @staticmethod
    def _ancestors(node):
        """Return the ``.parent`` chain of *node*, nearest ancestor first."""
        parents = []
        n = node
        while n := getattr(n, "parent", 0):
            parents.append(n)
        return parents

    def _inside_runtime_helper(self, parents):
        """True when any enclosing function definition is a runtime helper."""
        return any(isinstance(p, ast.FunctionDef) and p.name in self._RUNTIME_HELPERS for p in parents)

    @staticmethod
    def _helper_call(func_name, *args):
        """Build ``func_name(*args)`` as an AST call expression."""
        return ast.Call(func=ast.Name(id=func_name, ctx=ast.Load()), args=list(args), keywords=[])

    def visit_Attribute(self, node):
        parents = self._ancestors(node)
        if self._inside_runtime_helper(parents):
            return node

        # Assignment/deletion targets must stay plain attributes.
        if isinstance(node.ctx, (ast.Store, ast.Del)):
            return node

        if node.attr == "create_entity":
            return node

        # Dunder lookups (e.g. type(v).__name__) are Python introspection.
        if node.attr.startswith("__"):
            return node

        # Don't rewrite at module scope (top-level, no indent)
        enclosing_stmt = next((p for p in parents if isinstance(p, ast.stmt)), None)
        if enclosing_stmt is not None and isinstance(getattr(enclosing_stmt, "parent", None), ast.Module):
            return node

        new_value = self.visit(node.value)
        call = self._helper_call(
            "express_getattr",
            new_value,
            # ast.Constant replaces ast.Str, deprecated since 3.8 and removed in 3.14.
            ast.Constant(value=node.attr),
            ast.Name(id="INDETERMINATE", ctx=ast.Load()),
        )
        return ast.copy_location(call, node)

    def visit_Subscript(self, node):
        parents = self._ancestors(node)
        if self._inside_runtime_helper(parents):
            return node

        # Item assignment/deletion targets must stay plain subscripts.
        if isinstance(node.ctx, (ast.Store, ast.Del)):
            return node

        # Generated indices are a name, a constant or "expr - EXPRESS_ONE_BASED_INDEXING".
        assert (
            isinstance(node.slice, ast.Name)
            or isinstance(node.slice, ast.Constant)
            or isinstance(node.slice, ast.BinOp)
        )

        new_value = self.visit(node.value)
        call = self._helper_call(
            "express_getitem",
            new_value,
            node.slice,
            ast.Name(id="INDETERMINATE", ctx=ast.Load()),
        )
        return ast.copy_location(call, node)

    def assign_parent_refs(self, tree):
        """Give every node in *tree* a ``.parent`` attribute pointing to its parent node."""
        for node in ast.walk(tree):
            for child in ast.iter_child_nodes(node):
                child.parent = node
|
|
|
|
|
|
if __name__ == "__main__":
    # Usage: <script> SCHEMA.exp [OUTPUT.py | -]
    #
    # Parses an EXPRESS schema and compiles its WHERE rules, DERIVE
    # attributes, functions and global rules into a Python module. The module
    # is written to OUTPUT; the default is rules/<schema>.py next to this
    # script, and "-" writes to stdout. Per-declaration progress goes to
    # stderr so that stdout output stays valid Python.
    import html
    import io
    import shutil
    import subprocess

    schema = ifcopenshell.express.express_parser.parse(sys.argv[1]).schema

    try:
        ifcopenshell.ifcopenshell_wrapper.schema_by_name(schema.name)
    except Exception:
        # The wrapper does not know this schema yet: build and register it.
        # @nb note the difference here between:
        #
        # - ifcopenshell.express.express_parser.parse
        # - ifcopenshell.express.parse.parse
        #
        # First generates a pyparsing AST
        #
        # Second populates a latebound schema
        # that can be registered in C++.
        builder = ifcopenshell.express.parse(sys.argv[1])
        ifcopenshell.register_schema(builder)

    try:
        ofn = sys.argv[2]
    except IndexError:
        ofn = os.path.join(os.path.dirname(__file__), "rules", f"{schema.name}.py")

    output = io.StringIO()

    # --- Runtime preamble of the generated module ---------------------------

    print("import ifcopenshell", file=output, sep="\n")

    print(
        """
def is_indeterminate(v):
    return v is None or type(v).__name__ == 'indeterminate_type'

def exists(v):
    if callable(v):
        try: return v() is not None
        except IndexError as e: return False
    else: return not is_indeterminate(v)
""",
        "\n",
        file=output,
        sep="\n",
    )
    print(
        "def nvl(v, default): return v if not is_indeterminate(v) else default",
        "\n",
        file=output,
        sep="\n",
    )

    print(
        """
def is_entity(inst):
    if isinstance(inst, ifcopenshell.entity_instance):
        schema_name = inst.is_a(True).split('.')[0].lower()
        decl = ifcopenshell.ifcopenshell_wrapper.schema_by_name(schema_name).declaration_by_name(inst.is_a())
        return isinstance(decl, ifcopenshell.ifcopenshell_wrapper.entity)
    return False

def express_len(v):
    if isinstance(v, ifcopenshell.entity_instance) and not is_entity(v):
        v = v[0]
    elif is_indeterminate(v):
        return INDETERMINATE
    return len(v)

old_range = range

def range(*args):
    if any(map(is_indeterminate, args)):
        return
    yield from old_range(*args)

sizeof = express_len
hiindex = express_len
blength = express_len
""",
        file=output,
        sep="\n",
    )

    print("loindex = lambda x: 1", file=output, sep="\n")
    print("from math import *", file=output, sep="\n")

    # @todo this will get us in trouble when evaluating the truthness
    print("unknown = 'UNKNOWN'", file=output, sep="\n")

    print(
        """
def usedin(inst, ref_name):
    if inst is None:
        return []
    _, __, attr = ref_name.split('.')
    def filter():
        for ref, attr_idx in inst.wrapped_data.file.get_inverse(inst, allow_duplicate=True, with_attribute_indices=True):
            if ref.wrapped_data.get_attribute_names()[attr_idx].lower() == attr:
                yield ref
    return list(filter())


class express_set(set):
    def __mul__(self, other):
        return express_set(set(other) & self)
    __rmul__ = __mul__
    def __add__(self, other):
        def make_list(v):
            # Comply with 12.6.3 Union operator
            if isinstance(v, (list, tuple, set, express_set)):
                return list(v)
            else:
                return [v]
        return express_set(list(self) + make_list(other))
    __radd__ = __add__
    def __repr__(self):
        return repr(set(self))
    def __getitem__(self, k):
        # @todo this is obviously not stable, but should be good enough?
        return list(self)[k]


def express_getitem(aggr, idx, default):
    if aggr is None: return default
    if isinstance(aggr, ifcopenshell.entity_instance) and not is_entity(aggr):
        aggr = aggr[0]
    try: return aggr[idx]
    except IndexError as e: return None


def express_getattr(aggr, name, default):
    v = getattr(aggr, name, default)
    if v is None:
        return default
    else:
        return v


EXPRESS_ONE_BASED_INDEXING = 1


def typeof(inst):
    if not inst:
        # If V evaluates to indeterminate (?), an empty set is returned.
        return express_set([])
    schema_name = inst.is_a(True).split('.')[0].lower()
    def inner():
        decl = ifcopenshell.ifcopenshell_wrapper.schema_by_name(schema_name).declaration_by_name(inst.is_a())
        while decl:
            yield '.'.join((schema_name, decl.name().lower()))
            if isinstance(decl, ifcopenshell.ifcopenshell_wrapper.entity):
                decl = decl.supertype()
            else:
                decl = decl.declared_type()
                while isinstance(decl, ifcopenshell.ifcopenshell_wrapper.named_type):
                    decl = decl.declared_type()
                if not isinstance(decl, ifcopenshell.ifcopenshell_wrapper.type_declaration):
                    break
    return express_set(inner())

class indeterminate_type:
    def __bool__(self):
        return False
    def bop(self, *other):
        return self
    __lt__= bop
    __le__= bop
    __eq__= bop
    __ne__= bop
    __gt__= bop
    __ge__= bop
    __add__= bop
    __radd__= bop
    __sub__= bop
    __rsub__= bop
    __mul__= bop
    __rmul__= bop
    __truediv__= bop
    __floordiv__= bop
    __rtruediv__= bop
    __rfloordiv__= bop
    __mod__= bop
    __rmod__= bop
    __pow__= bop
    __rpow__= bop
    __neg__= bop
    __pos__= bop
    __getitem__ = bop
    __getattr__ = bop
    def __iter__(self):
        return iter(())

INDETERMINATE = indeterminate_type()

""",
        file=output,
        sep="\n",
    )

    print(
        "class enum_namespace:\n    def __getattr__(self, k):\n        return k.upper()",
        "\n",
        file=output,
        sep="\n",
    )

    # Enumerations: a namespace per enum plus a lower-case module-level alias per item.
    for k, v in schema.enumerations.items():
        print(f"{k} = enum_namespace()", "\n", file=output, sep="\n")

        for vi in v.values:
            print(f"{vi.lower()} = {k}.{vi}", "\n", file=output, sep="\n")

    # Entity constructors, used e.g. by complex entity instance construction.
    for k in schema.entities.keys():
        print(
            f"def {k}(*args, **kwargs): return ifcopenshell.create_entity({k!r}, {schema.name!r}, *args, **kwargs)",
            "\n",
            file=output,
            sep="\n",
        )

    # --- Compile every declaration -----------------------------------------

    for nm in schema.all_declarations.keys():
        print(nm, file=sys.stderr)

        tree = ifcopenshell.express.express_parser.to_tree(schema[nm])

        if DEBUG:
            with open(f"{nm}.json", "w", encoding="utf-8") as f:
                json.dump(tree, f, indent=2)

        G = to_graph(tree)
        rule_code = codegen_rule.apply(G)

        if DEBUG:
            # Show each node's generated value in an HTML-like Graphviz label.
            # Text must be entity-escaped to keep the label well-formed.
            for n in G.nodes.values():
                if v := n.get("value"):
                    if isinstance(v, str):
                        body = html.escape(v, quote=False).replace("\n", "<br/>")
                    elif isinstance(v, empty):
                        body = "---"
                    else:
                        continue
                    title = html.escape(str(n.get("label")), quote=False)
                    n["label"] = (
                        f'<<table cellborder="0" cellpadding="0"><tr><td><b>{title}</b></td></tr>'
                        f'<tr><td align="left" balign="left">{body}</td></tr></table>>'
                    )

            fn = f"{nm}.dot"
            write_dot(fn, G)
            subprocess.call([shutil.which("dot") or "dot", fn, "-O", "-Tpng"])

        print(rule_code, "\n", file=output, sep="\n")

    # --- Post-process and write --------------------------------------------

    tree = ast.parse(output.getvalue())
    trsf = AttributeGetattrTransformer()
    trsf.assign_parent_refs(tree)
    trsf.visit(tree)

    code = ast.unparse(tree)
    if ofn == "-":
        print(code)
    else:
        with open(ofn, "w", encoding="utf-8") as f:
            f.write(code)
|