Skip to content

Commit d5d1eb7

Browse files
committed
Swift: add structured C++ generated classes
This adds `cppgen`, creating structured C++ classes mirroring QL classes out of `schema.yml`. An example of generated code at the time of this commit can be found [in this gist][1]. [1]: https://gist.github.com/redsun82/57304ddb487a8aa40eaa0caa695048fa Closes github/codeql-c-team#863
1 parent 10c5c8e commit d5d1eb7

21 files changed

+445
-60
lines changed

swift/codegen/BUILD.bazel

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
load("@swift_codegen_deps//:requirements.bzl", "requirement")
22

3+
filegroup(
4+
name = "schema",
5+
srcs = ["schema.yml"],
6+
visibility = ["//swift:__subpackages__"],
7+
)
8+
9+
filegroup(
10+
name = "schema_includes",
11+
srcs = glob(["*.dbscheme"]),
12+
visibility = ["//swift:__subpackages__"],
13+
)
14+
315
py_binary(
416
name = "codegen",
517
srcs = glob(
@@ -15,6 +27,17 @@ py_binary(
1527
py_binary(
1628
name = "trapgen",
1729
srcs = ["trapgen.py"],
30+
data = ["//swift/codegen/templates:trap"],
31+
visibility = ["//swift:__subpackages__"],
32+
deps = [
33+
"//swift/codegen/lib",
34+
requirement("toposort"),
35+
],
36+
)
37+
38+
py_binary(
39+
name = "cppgen",
40+
srcs = ["cppgen.py"],
1841
data = ["//swift/codegen/templates:cpp"],
1942
visibility = ["//swift:__subpackages__"],
2043
deps = [

swift/codegen/cppgen.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import functools
2+
import inflection
3+
from typing import Dict
4+
5+
from toposort import toposort_flatten
6+
7+
from swift.codegen.lib import cpp, generator, schema
8+
9+
10+
def _get_type(t: str) -> str:
11+
if t == "string":
12+
return "std::string"
13+
if t == "boolean":
14+
return "bool"
15+
if t[0].isupper():
16+
return f"TrapLabel<{t}Tag>"
17+
return t
18+
19+
20+
def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
    """Build the `cpp.Field` describing property `p` of schema class `cls`.

    Non-single properties get a dedicated trap name derived from the class and
    property names. Names that collide with a C++ keyword get a trailing
    underscore. Any override returned by `cpp.get_field_override` for the
    property name is applied last and wins over the computed values.
    """
    if p.is_single:
        trap_name = None
    else:
        camelized = inflection.camelize(f"{cls.name}_{p.name}")
        trap_name = inflection.pluralize(camelized) + "Trap"

    field_name = p.name + "_" if p.name in cpp.cpp_keywords else p.name
    args = {
        "name": field_name,
        "type": _get_type(p.type),
        "is_optional": p.is_optional,
        "is_repeated": p.is_repeated,
        "trap_name": trap_name,
    }
    args.update(cpp.get_field_override(p.name))
    return cpp.Field(**args)
33+
34+
35+
class Processor:
    """Converts loaded schema classes into `cpp.Class` objects.

    `data` maps class names to their `schema.Class`. Each `cpp.Class` is built
    at most once per processor, so a base shared by several derived classes is
    represented by the same object in all of them.
    """

    def __init__(self, data: Dict[str, schema.Class]):
        self._classmap = data
        # Per-instance memo of already converted classes. This replaces a
        # `functools.cache` on the method, which keys on `self` in a
        # process-wide cache (keeping every Processor and its schema alive)
        # and is only available from Python 3.9.
        self._converted: Dict[str, cpp.Class] = {}

    def _get_class(self, name: str) -> cpp.Class:
        """Return the `cpp.Class` for the schema class called `name`.

        Raises KeyError if `name` is not in the class map.
        """
        try:
            return self._converted[name]
        except KeyError:
            pass
        cls = self._classmap[name]
        trap_name = None
        # A class gets its own trap entry when nothing derives from it or when
        # it has at least one single (non-optional, non-repeated) property.
        if not cls.derived or any(p.is_single for p in cls.properties):
            trap_name = inflection.pluralize(cls.name) + "Trap"
        converted = cpp.Class(
            name=name,
            bases=[self._get_class(b) for b in cls.bases],
            fields=[_get_field(cls, p) for p in cls.properties],
            final=not cls.derived,
            trap_name=trap_name,
        )
        self._converted[name] = converted
        return converted

    def get_classes(self):
        """Return all classes as `cpp.Class` objects in topologically sorted order.

        The order comes from `toposort_flatten` applied to the name -> bases
        mapping.
        """
        inheritance_graph = {k: cls.bases for k, cls in self._classmap.items()}
        return [self._get_class(cls) for cls in toposort_flatten(inheritance_graph)]
56+
57+
58+
def generate(opts, renderer):
    """Render the structured C++ classes header from the schema.

    Loads the schema at `opts.schema`, converts its classes with `Processor`
    and renders them as `TrapClasses.h` under `opts.cpp_output`.
    """
    by_name = {}
    for schema_class in schema.load(opts.schema).classes:
        by_name[schema_class.name] = schema_class
    processor = Processor(by_name)
    output_dir = opts.cpp_output
    class_list = cpp.ClassList(processor.get_classes())
    renderer.render(class_list, output_dir / "TrapClasses.h")
62+
63+
64+
tags = ("cpp", "schema")
65+
66+
if __name__ == "__main__":
67+
generator.run()

swift/codegen/lib/cpp.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from dataclasses import dataclass, field
23
from typing import List, ClassVar
34

@@ -14,13 +15,35 @@
1415
"typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while",
1516
"xor", "xor_eq"}
1617

18+
# Field overrides, tried in order. Each entry pairs a regex that must match the
# whole property name with either a dict of `Field` arguments or a callable
# that builds such a dict from the match object.
_field_overrides = [
    (re.compile(r"(start|end)_(line|column)|index|num_.*"), {"type": "unsigned"}),
    (re.compile(r"(.*)_"), lambda m: {"name": m[1]}),
]


def get_field_override(field: str):
    """Return the `Field` argument overrides for the property named `field`.

    The first entry of `_field_overrides` whose pattern fully matches `field`
    wins. An empty dict is returned when nothing matches. The result is always
    a fresh dict, so callers may mutate it without affecting later lookups.

    Note: inside this function the `field` parameter shadows the `field`
    helper imported from `dataclasses`. The name is kept for compatibility
    with keyword callers.
    """
    for pattern, override in _field_overrides:
        match = pattern.fullmatch(field)
        if match:
            resolved = override(match) if callable(override) else override
            # Copy so the shared table entry is never handed out directly.
            return dict(resolved)
    return {}
30+
1731

1832
@dataclass
1933
class Field:
2034
name: str
2135
type: str
36+
is_optional: bool = False
37+
is_repeated: bool = False
38+
trap_name: str = None
2239
first: bool = False
2340

41+
def __post_init__(self):
42+
if self.is_optional:
43+
self.type = f"std::optional<{self.type}>"
44+
elif self.is_repeated:
45+
self.type = f"std::vector<{self.type}>"
46+
2447
@property
2548
def cpp_name(self):
2649
if self.name in cpp_keywords:
@@ -36,6 +59,12 @@ def get_streamer(self):
3659
else:
3760
return lambda x: x
3861

62+
@property
63+
def is_single(self):
64+
return not (self.is_optional or self.is_repeated)
65+
66+
67+
3968

4069
@dataclass
4170
class Trap:
@@ -74,13 +103,48 @@ def has_bases(self):
74103

75104
@dataclass
76105
class TrapList:
77-
template: ClassVar = 'cpp_traps'
106+
template: ClassVar = 'trap_traps'
78107

79108
traps: List[Trap] = field(default_factory=list)
80109

81110

82111
@dataclass
83112
class TagList:
84-
template: ClassVar = 'cpp_tags'
113+
template: ClassVar = 'trap_tags'
85114

86115
tags: List[Tag] = field(default_factory=list)
116+
117+
118+
@dataclass
119+
class ClassBase:
120+
ref: 'Class'
121+
first: bool = False
122+
123+
124+
@dataclass
class Class:
    """Description of one generated C++ class, as consumed by the class template.

    `bases` is passed in as a list of `Class` objects. After construction it
    holds `ClassBase` wrappers sorted by the referenced class name, with only
    the first wrapper flagged `first`.
    """
    name: str
    bases: List[ClassBase] = field(default_factory=list)
    final: bool = False
    fields: List[Field] = field(default_factory=list)
    trap_name: str = None

    def __post_init__(self):
        # Sort by name, then wrap each base, flagging only the first one.
        by_name = sorted(self.bases, key=lambda base: base.name)
        self.bases = [
            ClassBase(base, first=(position == 0))
            for position, base in enumerate(by_name)
        ]

    @property
    def has_bases(self):
        """Whether the class has at least one base."""
        return len(self.bases) > 0

    @property
    def single_fields(self):
        """The fields that are neither optional nor repeated, in declaration order."""
        return list(filter(lambda f: f.is_single, self.fields))
144+
145+
146+
@dataclass
147+
class ClassList:
148+
template: ClassVar = "cpp_classes"
149+
150+
classes: List[Class]

swift/codegen/lib/options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def _init_options():
1515
Option("--ql-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/generated")
1616
Option("--ql-stub-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/elements")
1717
Option("--codeql-binary", tags=["ql"], default="codeql")
18-
Option("--trap-output", tags=["trap"], type=_abspath, required=True)
18+
Option("--cpp-output", tags=["cpp"], type=_abspath, required=True)
1919

2020

2121
def _abspath(x):
Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
package(default_visibility = ["//swift:__subpackages__"])
2+
3+
filegroup(
4+
name = "trap",
5+
srcs = glob(["trap_*.mustache"]),
6+
)
7+
18
filegroup(
29
name = "cpp",
310
srcs = glob(["cpp_*.mustache"]),
4-
visibility = ["//swift:__subpackages__"],
511
)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// generated by {{generator}}
2+
// clang-format off
3+
#pragma once
4+
5+
#include <iostream>
6+
#include <optional>
7+
#include <vector>
8+
9+
#include "swift/extractor/trap/TrapLabel.h"
10+
#include "swift/extractor/trap/TrapEntries.h"
11+
12+
namespace codeql {
13+
{{#classes}}
14+
15+
struct {{name}}{{#final}} : Binding<{{name}}Tag>{{#bases}}, {{ref.name}}{{/bases}}{{/final}}{{^final}}{{#has_bases}}: {{#bases}}{{^first}}, {{/first}}{{ref.name}}{{/bases}}{{/has_bases}}{{/final}} {
16+
{{#fields}}
17+
{{type}} {{name}}{};
18+
{{/fields}}
19+
{{#final}}
20+
21+
friend std::ostream& operator<<(std::ostream& out, const {{name}}& x) {
22+
x.emit(out);
23+
return out;
24+
}
25+
{{/final}}
26+
27+
protected:
28+
void emit({{^final}}TrapLabel<{{name}}Tag> id, {{/final}}std::ostream& out) const {
29+
{{#bases}}
30+
{{ref.name}}::emit(id, out);
31+
{{/bases}}
32+
{{#trap_name}}
33+
out << {{.}}{id{{#single_fields}}, {{name}}{{/single_fields}}} << '\n';
34+
{{/trap_name}}
35+
{{#fields}}
36+
{{#is_optional}}
37+
if ({{name}}) out << {{trap_name}}{id, *{{name}}} << '\n';
38+
{{/is_optional}}
39+
{{#is_repeated}}
{{! one trap entry per element; each entry ends with a newline like every other entry written by emit }}
for (auto i = 0u; i < {{name}}.size(); ++i) out << {{trap_name}}{id, i, {{name}}[i]} << '\n';
{{/is_repeated}}
42+
{{/fields}}
43+
}
44+
};
45+
{{/classes}}
46+
}
File renamed without changes.
File renamed without changes.

swift/codegen/test/test_cpp.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,27 @@ def test_field_get_streamer(type, expected):
2727
assert f.get_streamer()("value") == expected
2828

2929

30+
@pytest.mark.parametrize("is_optional,is_repeated,expected", [
    (False, False, True),
    (True, False, False),
    (False, True, False),
    (True, True, False),
])
def test_field_is_single(is_optional, is_repeated, expected):
    """A field is single only when it is neither optional nor repeated."""
    flags = {"is_optional": is_optional, "is_repeated": is_repeated}
    field_under_test = cpp.Field("name", "type", **flags)
    assert field_under_test.is_single is expected
39+
40+
41+
@pytest.mark.parametrize("is_optional,is_repeated,expected", [
    (False, False, "bar"),
    (True, False, "std::optional<bar>"),
    (False, True, "std::vector<bar>"),
])
def test_field_modal_types(is_optional, is_repeated, expected):
    """The multiplicity flags wrap the base type in the matching C++ container."""
    flags = {"is_optional": is_optional, "is_repeated": is_repeated}
    field_under_test = cpp.Field("name", "bar", **flags)
    assert field_under_test.type == expected
49+
50+
3051
def test_trap_has_first_field_marked():
3152
fields = [
3253
cpp.Field("a", "x"),
@@ -56,5 +77,39 @@ def test_tag_has_bases(bases, expected):
5677
assert t.has_bases is expected
5778

5879

80+
def test_class_has_first_base_marked():
    """Bases are wrapped in `ClassBase`, with only the first one flagged `first`."""
    bases = [cpp.Class(base_name) for base_name in ("a", "b", "c")]
    expected = [
        cpp.ClassBase(base, first=(position == 0))
        for position, base in enumerate(bases)
    ]
    derived = cpp.Class("foo", bases=bases)
    assert derived.bases == expected
90+
91+
92+
@pytest.mark.parametrize("bases,expected", [
    ([], False),
    (["a"], True),
    (["a", "b"], True)
])
def test_class_has_bases(bases, expected):
    """`has_bases` is true exactly when at least one base is given."""
    base_classes = list(map(cpp.Class, bases))
    derived = cpp.Class("name", base_classes)
    assert derived.has_bases is expected
100+
101+
102+
def test_class_single_fields():
    """`single_fields` keeps only non-optional, non-repeated fields, in order."""
    a = cpp.Field("a", "A")
    b = cpp.Field("b", "B", is_optional=True)
    c = cpp.Field("c", "C")
    d = cpp.Field("d", "D", is_repeated=True)
    e = cpp.Field("e", "E")
    holder = cpp.Class("foo", fields=[a, b, c, d, e])
    assert holder.single_fields == [a, c, e]
112+
113+
59114
if __name__ == '__main__':
60-
sys.exit(pytest.main())
115+
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)