diff --git a/ser/codegen.h b/ser/codegen.h
index dc2d3cd..b635215 100644
--- a/ser/codegen.h
+++ b/ser/codegen.h
@@ -8,6 +8,11 @@
 #include
 #include
 
+#define INDENT 4
+
+#define MSG_MAGIC_START 0xCAFEF00DBEEFDEADUL
+#define MSG_MAGIC_END 0xF00DBEEFCAFEDEADUL
+
 // Struct used to define the relative alignment when working with structs
 typedef struct {
     Alignment align;
diff --git a/ser/codegen_c.c b/ser/codegen_c.c
index 3ff7609..a4d4d0c 100644
--- a/ser/codegen_c.c
+++ b/ser/codegen_c.c
@@ -5,8 +5,6 @@
 
 #include
 
-#define INDENT 4
-
 typedef enum {
     MTPointer,
     MTArray,
@@ -34,9 +32,9 @@ static inline const char *array_size_type(uint64_t size) {
     }
 }
 
-void write_field(Writer *w, Field f, Modifier *mods, size_t len, uint32_t indent);
+static void write_field(Writer *w, Field f, Modifier *mods, size_t len, uint32_t indent);
 // Wrte the *base* type type with indentation
-void write_type(Writer *w, TypeObject *type, uint32_t indent) {
+static void write_type(Writer *w, TypeObject *type, uint32_t indent) {
     if (type->kind == TypePrimitif) {
 #define _case(x, s) \
     case Primitif_##x: \
@@ -145,7 +143,7 @@ void write_type(Writer *w, TypeObject *type, uint32_t indent) {
 // 7.
 // result = "char bar"
 
-void write_field(Writer *w, Field f, Modifier *mods, size_t len, uint32_t indent) {
+static void write_field(Writer *w, Field f, Modifier *mods, size_t len, uint32_t indent) {
     TypeObject *type = f.type;
     ModifierVec modifiers = vec_init();
     TypeObject *current = type;
@@ -185,7 +183,7 @@ void write_field(Writer *w, Field f, Modifier *mods, size_t len, uint32_t indent
     _vec_modifier_drop(modifiers);
 }
 
-void write_struct(Writer *w, StructObject *obj, void *_user_data) {
+static void write_struct(Writer *w, StructObject *obj, void *_user_data) {
     wt_format(w, "typedef struct %.*s {\n", obj->name.len, obj->name.ptr);
 
     for (size_t i = 0; i < obj->fields.len; i++) {
@@ -197,11 +195,11 @@ void write_struct(Writer *w, StructObject *obj, void *_user_data) {
     wt_format(w, "} %.*s;\n\n", obj->name.len, obj->name.ptr);
 }
 
-void write_align(Writer *w, const char *var, const Alignment align, size_t indent) {
+static void write_align(Writer *w, const char *var, const Alignment align, size_t indent) {
     wt_format(w, "%*s%s = (byte*)(((((uintptr_t)%s - 1) >> %u) + 1) << %u);\n", indent, "", var, var, align.po2, align.po2);
 }
 
-void write_accessor(Writer *w, TypeObject *base_type, FieldAccessor fa, bool ptr) {
+static void write_accessor(Writer *w, TypeObject *base_type, FieldAccessor fa, bool ptr) {
     if (fa.indices.len == 0)
         return;
 
@@ -251,7 +249,7 @@ void write_accessor(Writer *w, TypeObject *base_type, FieldAccessor fa, bool ptr
     }
 }
 
-bool is_field_accessor_heap_array(FieldAccessor fa, TypeObject *base_type) {
+static bool is_field_accessor_heap_array(FieldAccessor fa, TypeObject *base_type) {
     if (fa.indices.len == 0)
         return base_type->kind == TypeArray && base_type->type.array.heap;
 
@@ -282,7 +280,7 @@ bool is_field_accessor_heap_array(FieldAccessor fa, TypeObject *base_type) {
     return t->kind == TypeArray && t->type.array.heap;
 }
 
-void write_type_serialization(
+static void write_type_serialization(
     Writer *w, const char *base, bool ptr, Layout *layout, CurrentAlignment al, Hashmap *layouts, size_t indent, size_t depth, bool always_inline
 ) {
     if (layout->fields.len == 0)
@@ -357,7 +355,7 @@ void write_type_serialization(
     }
 }
 
-void write_type_deserialization(
+static void write_type_deserialization(
     Writer *w, const char *base, bool ptr, Layout *layout, CurrentAlignment al, Hashmap *layouts, size_t indent, size_t depth, bool always_inline
 ) {
     if (layout->fields.len == 0)
@@ -447,7 +445,7 @@ void write_type_deserialization(
     }
 }
 
-int write_type_free(Writer *w, const char *base, TypeObject *type, Hashmap *layouts, size_t indent, size_t depth, bool always_inline) {
+static int write_type_free(Writer *w, const char *base, TypeObject *type, Hashmap *layouts, size_t indent, size_t depth, bool always_inline) {
     if (type->kind == TypePrimitif) {
         return 0;
     } else if (type->kind == TypeArray) {
@@ -512,7 +510,7 @@ int write_type_free(Writer *w, const char *base, TypeObject *type, Hashmap *layo
     return 0;
 }
 
-void write_struct_func_decl(Writer *w, StructObject *obj, void *_user_data) {
+static void write_struct_func_decl(Writer *w, StructObject *obj, void *_user_data) {
     obj->has_funcs = true;
 
     StringSlice sname = obj->name;
@@ -523,7 +521,7 @@ void write_struct_func_decl(Writer *w, StructObject *obj, void *_user_data) {
     free(snake_case_name);
 }
 
-void write_struct_func(Writer *w, StructObject *obj, void *user_data) {
+static void write_struct_func(Writer *w, StructObject *obj, void *user_data) {
     Hashmap *layouts = user_data;
     // Retreive original TypeObject pointer from struct object pointer.
     TypeObject *t = (void *)((byte *)obj - offsetof(struct TypeObject, type));
@@ -562,7 +560,7 @@ void codegen_c(Writer *header, Writer *source, const char *name, Program *p) {
     char *uc_name = snake_case_to_screaming_snake_case((StringSlice){.ptr = name, .len = strlen(name)});
     wt_format(
         header,
-        "// Generated file, do not edit (its not like it'll explode if you do, but its better not to)\n"
+        "// Generated file\n"
         "#ifndef %s_H\n"
         "#define %s_H\n"
         "#include \n"
@@ -573,16 +571,18 @@ void codegen_c(Writer *header, Writer *source, const char *name, Program *p) {
         "typedef uint64_t MsgMagic;\n"
         "\n"
         "#define MSG_MAGIC_SIZE sizeof(MsgMagic)\n"
-        "static const MsgMagic MSG_MAGIC_START = 0xCAFEF00DBEEFDEAD;\n"
-        "static const MsgMagic MSG_MAGIC_END = 0xF00DBEEFCAFEDEAD;\n"
+        "static const MsgMagic MSG_MAGIC_START = 0x%016lX;\n"
+        "static const MsgMagic MSG_MAGIC_END = 0x%016lX;\n"
         "\n",
         uc_name,
-        uc_name
+        uc_name,
+        MSG_MAGIC_START,
+        MSG_MAGIC_END
     );
     free(uc_name);
     wt_format(
         source,
-        "// Generated file, do not edit (its not like it'll explode if you do, but its better not to)\n"
+        "// Generated file\n"
         "#include \"%s.h\"\n"
         "#include \n"
         "\n",
diff --git a/ser/codegen_python.c b/ser/codegen_python.c
new file mode 100644
index 0000000..1223b49
--- /dev/null
+++ b/ser/codegen_python.c
@@ -0,0 +1,553 @@
+#include "codegen_python.h"
+
+#include
+
+// Write the Python type annotation for `type`.
+// Struct names not yet present in `defined` are quoted as forward references; a NULL `defined` never quotes.
+static void write_type_name(Writer *w, TypeObject *type, Hashmap *defined) {
+    if (type->kind == TypePrimitif) {
+#define _case(x, str) \
+    case Primitif_##x: \
+        wt_format(w, str); \
+        return
+        switch (type->type.primitif) {
+            _case(u8, "int");
+            _case(u16, "int");
+            _case(u32, "int");
+            _case(u64, "int");
+            _case(i8, "int");
+            _case(i16, "int");
+            _case(i32, "int");
+            _case(i64, "int");
+            _case(f32, "float");
+            _case(f64, "float");
+            _case(bool, "bool");
+            _case(char, "str");
+        }
+#undef _case
+    } else if (type->kind == TypeArray) {
+        // If we have an array of char
+        if (type->type.array.type->kind == TypePrimitif && type->type.array.type->type.primitif == Primitif_char) {
+            wt_format(w, "str");
+            return;
+        }
+        wt_format(w, "List[");
+        write_type_name(w, type->type.array.type, defined);
+        wt_format(w, "]");
+    } else if (type->kind == TypeStruct) {
+        StructObject *obj = (StructObject *)&type->type.struct_;
+        if (defined == NULL || hashmap_has(defined, &obj)) {
+            wt_write(w, obj->name.ptr, obj->name.len);
+        } else {
+            wt_format(w, "'%.*s'", obj->name.len, obj->name.ptr);
+        }
+    }
+}
+
+typedef enum {
+    Read,
+    Write,
+} Access;
+
+// Write the Python expression reached by following `fa` from `base` through `type`.
+// The walk stops at the first SizingMax array: index 0 there designates the array length, written as
+// `len(<path>)` for Read and `<path with '.' replaced by '_'>_len` for Write; any other index writes the path itself.
+static void write_field_accessor(Writer *w, const char *base, FieldAccessor fa, TypeObject *type, Access access) {
+    if (fa.indices.len == 0) {
+        wt_format(w, "%s", base);
+        return;
+    }
+
+    BufferedWriter b = buffered_writer_init();
+    Writer *w2 = (Writer *)&b;
+
+    bool len = false;
+    wt_format(w2, "%s", base);
+    for (size_t i = 0; i < fa.indices.len; i++) {
+        uint64_t index = fa.indices.data[i];
+
+        if (type->kind == TypeStruct) {
+            StructObject *obj = (StructObject *)&type->type.struct_;
+            Field f = obj->fields.data[index];
+            wt_format(w2, ".%.*s", f.name.len, f.name.ptr);
+
+            type = f.type;
+        } else if (type->kind == TypeArray) {
+            if (type->type.array.sizing == SizingMax) {
+                len = index == 0;
+                break;
+            } else {
+                wt_format(w2, "[%lu]", index);
+                type = type->type.array.type;
+            }
+        }
+    }
+
+    if (len) {
+        if (access == Read) {
+            wt_format(w, "len(%.*s)", b.buf.len, b.buf.data);
+        } else {
+            CharVec buf = b.buf;
+            for (size_t i = 0; i < buf.len; i++) {
+                if (buf.data[i] == '.') {
+                    buf.data[i] = '_';
+                }
+            }
+            wt_write(w, buf.data, buf.len);
+            wt_format(w, "_len");
+        }
+    } else {
+        wt_write(w, b.buf.data, b.buf.len);
+    }
+
+    buffered_writer_drop(b);
+}
+
+// Return true when `fa` is a primitive accessor whose last index is 0 and whose parent is a SizingMax array.
+static bool field_accessor_is_array_length(FieldAccessor fa, TypeObject *type) {
+    if (fa.indices.len == 0 || fa.indices.data[fa.indices.len - 1] != 0 || fa.type->kind != TypePrimitif)
+        return false;
+
+    for (size_t i = 0; i < fa.indices.len - 1; i++) {
+        uint64_t index = fa.indices.data[i];
+        if (type->kind == TypeStruct) {
+            StructObject *s = (StructObject *)&type->type.struct_;
+
+            type = s->fields.data[index].type;
+        } else if (type->kind == TypeArray) {
+            type = type->type.array.type;
+        }
+    }
+
+    return type->kind == TypeArray && type->type.array.sizing == SizingMax;
+}
+
+// Find the index of the FieldAccessor that points to the length of this array
+static uint64_t get_array_length_field_index(Layout *layout, FieldAccessor fa) {
+    size_t index = SIZE_MAX;
+    for (size_t j = 0; j < layout->fields.len && layout->fields.data[j].size != 0; j++) {
+        FieldAccessor f = layout->fields.data[j];
+        // An empty accessor has no last index to compare; skipping it also keeps `len - 1` below from wrapping.
+        if (f.indices.len == 0 || f.indices.len != fa.indices.len)
+            continue;
+
+        // Check the indices for equality, all but the last one should be equal
+        bool equal = true;
+        for (size_t k = 0; k < f.indices.len - 1; k++) {
+            if (f.indices.data[k] != fa.indices.data[k]) {
+                equal = false;
+                break;
+            }
+        }
+
+        if (equal && f.indices.data[f.indices.len - 1] == 0) {
+            index = j;
+            break;
+        }
+    }
+
+    if (index == SIZE_MAX) {
+        log_error("No length accessor for variable size array accessor");
+        exit(1);
+    }
+
+    return index;
+}
+
+// Write the body statements of `uninit`: struct fields are set to their own uninit(), array fields to [].
+static void write_type_uninit(Writer *u, TypeObject *type) {
+    if (type->kind == TypeStruct) {
+        StructObject *s = (StructObject *)&type->type.struct_;
+        for (size_t i = 0; i < s->fields.len; i++) {
+            Field f = s->fields.data[i];
+            if (f.type->kind == TypeStruct) {
+                StringSlice tname = f.type->type.struct_.name;
+                wt_format(u, "%*sself.%.*s = %.*s.uninit()\n", INDENT * 2, "", f.name.len, f.name.ptr, tname.len, tname.ptr);
+            } else if (f.type->kind == TypeArray) {
+                wt_format(u, "%*sself.%.*s = []\n", INDENT * 2, "", f.name.len, f.name.ptr);
+            }
+        }
+    }
+}
+
+// Emit the Python serialization statements into `s` and the matching deserialization statements into `d`.
+// Leading layout fields with a non-zero size go through a single pack/unpack call; the remaining fields are emitted one by one.
+static void write_type_funcs(
+    Writer *s,
+    Writer *d,
+    TypeObject *type,
+    const char *base,
+    CurrentAlignment al,
+    Hashmap *layouts,
+    size_t indent,
+    size_t depth,
+    bool always_inline
+) {
+    Layout *layout = hashmap_get(layouts, &(Layout){.type = type});
+    assert(layout != NULL, "Type has no layout");
+
+    if (layout->fields.len == 0)
+        return;
+
+    Alignment align = al.align;
+    size_t offset = calign_to(al, layout->fields.data[0].type->align);
+
+    if (offset > 0) {
+        wt_format(s, "%*sbuf += bytes(%lu)\n", indent, "", offset);
+        wt_format(d, "%*soff += %lu\n", indent, "", offset);
+    }
+
+    if (type->kind == TypeStruct && !always_inline) {
+        offset = calign_to(al, layout->type->align);
+        if (offset != 0) {
+            wt_format(s, "%*sbuf += bytes(%lu)\n", indent, "", offset);
+            // Advance (not reset) the read offset so it stays in step with the padding the serializer appends.
+            wt_format(d, "%*soff += %lu\n", indent, "", offset);
+        }
+        wt_format(s, "%*s%s.serialize(buf)\n", indent, "", base);
+        wt_format(d, "%*soff += %s.deserialize(buf[off:])\n", indent, "", base);
+        return;
+    }
+
+    wt_format(s, "%*sbuf += pack('<", indent, "");
+    wt_format(d, "%*sxs%lu = unpack('<", indent, "", depth);
+    al = calign_add(al, offset);
+
+    size_t size = 0;
+    size_t i = 0;
+    for (; i < layout->fields.len && layout->fields.data[i].size != 0; i++) {
+        FieldAccessor fa = layout->fields.data[i];
+        assert(fa.type->kind == TypePrimitif, "Field accessor of non zero size doesn't point to primitive type");
+
+#define _case(x, f) \
+    case Primitif_##x: \
+        wt_format(s, f); \
+        wt_format(d, f); \
+        break
+        switch (fa.type->type.primitif) {
+            _case(u8, "B");
+            _case(u16, "H");
+            _case(u32, "I");
+            _case(u64, "Q");
+            _case(i8, "b");
+            _case(i16, "h");
+            _case(i32, "i");
+            _case(i64, "q");
+            _case(f32, "f");
+            _case(f64, "d");
+            _case(bool, "?");
+            _case(char, "c");
+        }
+#undef _case
+        al = calign_add(al, fa.size);
+        offset += fa.size;
+        size += fa.size;
+    }
+
+    size_t padding = 0;
+    if (i < layout->fields.len) {
+        padding = calign_to(al, layout->fields.data[i].type->align);
+        wt_format(s, "%lux", padding);
+    }
+
+    wt_format(s, "',\n");
+    wt_format(d, "', buf[off:off + %lu])\n", size);
+
+    for (size_t j = 0; j < i; j++) {
+        FieldAccessor fa = layout->fields.data[j];
+
+        wt_format(s, "%*s ", indent, "");
+        write_field_accessor(s, base, fa, type, Read);
+        if (fa.type->kind == TypePrimitif && fa.type->type.primitif == Primitif_char) {
+            wt_format(s, ".encode(encoding='ASCII', errors='replace')");
+        }
+        if (j < i - 1) {
+            wt_format(s, ",\n");
+        } else {
+            wt_format(s, ")\n");
+        }
+
+        if (!field_accessor_is_array_length(fa, type)) {
+            wt_format(d, "%*s", indent, "");
+            write_field_accessor(d, base, fa, type, Write);
+            wt_format(d, " = xs%lu[%lu]", depth, j);
+            if (fa.type->kind == TypePrimitif && fa.type->type.primitif == Primitif_char) {
+                wt_format(d, ".decode(encoding='ASCII', errors='replace')");
+            }
+            wt_format(d, "\n");
+        }
+    }
+
+    if (i < layout->fields.len) {
+        wt_format(d, "%*soff += %lu\n", indent, "", padding + size);
+    } else {
+        wt_format(d, "%*soff += %lu\n", indent, "", padding + size + calign_to(al, align));
+    }
+
+    bool alignment_unknown = false;
+    for (; i < layout->fields.len; i++) {
+        alignment_unknown = true;
+
+        FieldAccessor fa = layout->fields.data[i];
+        uint64_t len_index = get_array_length_field_index(layout, fa);
+
+        if (fa.type->kind == TypePrimitif && fa.type->type.primitif == Primitif_char) {
+            wt_format(s, "%*sbuf += ", indent, "");
+            wt_format(d, "%*s", indent, "");
+            write_field_accessor(s, base, fa, type, Read);
+            write_field_accessor(d, base, fa, type, Write);
+
+            wt_format(d, " = buf[off:off + xs%lu[%lu]].decode(encoding='ASCII', errors='replace')\n", depth, len_index);
+            wt_format(d, "%*soff += xs%lu[%lu]\n", indent, "", depth, len_index);
+            wt_format(s, ".encode(encoding='ASCII', errors='replace')\n");
+            continue;
+        }
+
+        wt_format(s, "%*sfor e%lu in ", indent, "", depth);
+        write_field_accessor(s, base, fa, type, Read);
+        wt_format(s, ":\n");
+        wt_format(d, "%*s", indent, "");
+        write_field_accessor(d, base, fa, type, Write);
+        wt_format(d, " = []\n");
+        wt_format(d, "%*sfor _ in range(xs%lu[%lu]):\n", indent, "", depth, len_index);
+        if (fa.type->kind == TypeArray && fa.type->type.array.sizing == SizingFixed) {
+            wt_format(d, "%*se%lu = [None] * %lu\n", indent + INDENT, "", depth, fa.type->type.array.size);
+        } else if (fa.type->kind == TypeStruct) {
+            struct StructObject s = fa.type->type.struct_;
+            wt_format(d, "%*se%lu = %.*s.uninit()\n", indent + INDENT, "", depth, s.name.len, s.name.ptr);
+        }
+        char *new_base = msprintf("e%lu", depth);
+        write_type_funcs(
+            s,
+            d,
+            fa.type,
+            new_base,
+            (CurrentAlignment){.align = fa.type->align, .offset = 0},
+            layouts,
+            indent + INDENT,
+            depth + 1,
+            false
+        );
+        wt_format(d, "%*s", indent + INDENT, "");
+        write_field_accessor(d, base, fa, type, Write);
+        wt_format(d, ".append(%s)\n", new_base);
+        free(new_base);
+    }
+
+    if (alignment_unknown) {
+        wt_format(s, "%*sbuf += bytes((%u - len(buf)) & %u)\n", indent, "", align.value, align.mask);
+        wt_format(d, "%*soff += (%u - off) & %u\n", indent, "", align.value, align.mask);
+    }
+}
+
+// Write a Python dataclass for `obj` with uninit/serialize/deserialize methods, then record `obj` in `defined`.
+static void write_struct_class(Writer *w, StructObject *obj, Hashmap *defined, Hashmap *layouts) {
+    TypeObject *type = (void *)((byte *)obj - offsetof(TypeObject, type));
+
+    wt_format(w, "@dataclass\n");
+    wt_format(w, "class %.*s:\n", obj->name.len, obj->name.ptr);
+    for (size_t i = 0; i < obj->fields.len; i++) {
+        Field f = obj->fields.data[i];
+        wt_format(w, "%*s%.*s: ", INDENT, "", f.name.len, f.name.ptr);
+        write_type_name(w, f.type, defined);
+        wt_format(w, "\n");
+    }
+    BufferedWriter ser = buffered_writer_init();
+    BufferedWriter deser = buffered_writer_init();
+    write_type_funcs(
+        (Writer *)&ser,
+        (Writer *)&deser,
+        type,
+        "self",
+        (CurrentAlignment){.align = type->align, .offset = 0},
+        layouts,
+        INDENT * 2,
+        0,
+        true
+    );
+
+    wt_format(w, "%*s\n", INDENT, "");
+    wt_format(w, "%*s@classmethod\n", INDENT, "");
+    wt_format(w, "%*sdef uninit(cls) -> '%.*s':\n", INDENT, "", obj->name.len, obj->name.ptr);
+    wt_format(w, "%*sself = cls.__new__(cls)\n", INDENT * 2, "");
+    write_type_uninit(w, type);
+    wt_format(w, "%*sreturn self\n", INDENT * 2, "");
+    wt_format(w, "%*s\n", INDENT, "");
+    wt_format(w, "%*sdef serialize(self, buf: bytearray):\n", INDENT, "");
+    wt_format(w, "%*sbase = len(buf)\n", INDENT * 2, "");
+    wt_write(w, ser.buf.data, ser.buf.len);
+    wt_format(w, "%*sreturn len(buf) - base\n", INDENT * 2, "");
+    wt_format(w, "%*s\n", INDENT, "");
+    wt_format(w, "%*sdef deserialize(self, buf: bytes):\n", INDENT, "");
+    wt_format(w, "%*soff = 0\n", INDENT * 2, "");
+    wt_write(w, deser.buf.data, deser.buf.len);
+    wt_format(w, "%*sreturn off\n", INDENT * 2, "");
+    wt_format(w, "\n");
+
+    buffered_writer_drop(ser);
+    buffered_writer_drop(deser);
+
+    hashmap_set(defined, &obj);
+}
+
+typedef struct {
+    Hashmap *layouts;
+    Hashmap *defined;
+} CallbackData;
+
+static void write_struct(Writer *w, StructObject *obj, void *user_data) {
+    CallbackData *data = user_data;
+    write_struct_class(w, obj, data->defined, data->layouts);
+}
+
+static void define_struct_classes(Writer *w, Program *p) {
+    Hashmap *defined = hashmap_init(pointer_hash, pointer_equal, NULL, sizeof(StructObject *));
+    CallbackData data = {.defined = defined, .layouts = p->layouts};
+    define_structs(p, w, write_struct, &data);
+    hashmap_drop(defined);
+}
+
+// Write the Python dataclass for one message; its payload is framed by MSG_MAGIC_START + `tag` and MSG_MAGIC_END.
+static void define_message(Writer *w, const char *prefix, uint16_t tag, Hashmap *layouts, MessageObject msg, uint64_t version) {
+    char *name = msprintf("%s%.*s", prefix, msg.name.len, msg.name.ptr);
+    StringSlice name_slice = {.ptr = name, .len = strlen(name)};
+
+    wt_format(w, "@dataclass\n");
+    wt_format(w, "class %s(%sMessage):\n", name, prefix);
+
+    TypeObject *type;
+    StructObject *obj;
+    FieldVec fields = vec_clone(&msg.fields);
+    {
+        if (msg.attributes & Attr_versioned) {
+            Field f = {.name = STRING_SLICE("_version"), .type = (TypeObject *)&PRIMITIF_u64};
+            vec_push(&fields, f);
+        }
+
+        type = malloc(sizeof(TypeObject));
+        assert_alloc(type);
+        type->kind = TypeStruct;
+        type->type.struct_.name = name_slice;
+        type->type.struct_.has_funcs = false;
+        type->type.struct_.fields = *(AnyVec *)&fields;
+        type->align = ALIGN_8;
+        obj = (StructObject *)&type->type.struct_;
+
+        Layout l = type_layout(type);
+        hashmap_set(layouts, &l);
+    }
+
+    for (size_t i = 0; i < msg.fields.len; i++) {
+        Field f = msg.fields.data[i];
+        wt_format(w, "%*s%.*s: ", INDENT, "", f.name.len, f.name.ptr);
+        write_type_name(w, f.type, NULL);
+        wt_format(w, "\n");
+    }
+
+    if (msg.attributes & Attr_versioned) {
+        wt_format(w, "%*s_version: int = %lu\n", INDENT, "", version);
+    }
+
+    BufferedWriter ser = buffered_writer_init();
+    BufferedWriter deser = buffered_writer_init();
+    write_type_funcs(
+        (Writer *)&ser,
+        (Writer *)&deser,
+        type,
+        "self",
+        (CurrentAlignment){.align = type->align, .offset = 2},
+        layouts,
+        INDENT * 2,
+        0,
+        true
+    );
+
+    wt_format(w, "%*s\n", INDENT, "");
+    wt_format(w, "%*s@classmethod\n", INDENT, "");
+    wt_format(w, "%*sdef uninit(cls) -> '%s':\n", INDENT, "", name);
+    wt_format(w, "%*sself = cls.__new__(cls)\n", INDENT * 2, "");
+    write_type_uninit(w, type);
+    wt_format(w, "%*sreturn self\n", INDENT * 2, "");
+    wt_format(w, "%*s\n", INDENT, "");
+    wt_format(w, "%*sdef serialize(self, buf: bytearray):\n", INDENT, "");
+    wt_format(w, "%*sbase = len(buf)\n", INDENT * 2, "");
+    wt_format(w, "%*sbuf += pack('>QH', MSG_MAGIC_START, %u)\n", INDENT * 2, "", tag);
+    wt_write(w, ser.buf.data, ser.buf.len);
+    wt_format(w, "%*sbuf += pack('>Q', MSG_MAGIC_END)\n", INDENT * 2, "");
+    wt_format(w, "%*sreturn len(buf) - base\n", INDENT * 2, "");
+    wt_format(w, "%*s\n", INDENT, "");
+    wt_format(w, "%*s@classmethod\n", INDENT, "");
+    wt_format(w, "%*sdef _deserialize(cls, buf: bytes) -> Tuple['%s', int]:\n", INDENT, "", name);
+    wt_format(w, "%*smagic_start, tag = unpack('>QH', buf[0:10])\n", INDENT * 2, "");
+    wt_format(w, "%*sif magic_start != MSG_MAGIC_START or tag != %u:\n", INDENT * 2, "", tag);
+    wt_format(w, "%*sraise ValueError\n", INDENT * 3, "");
+    wt_format(w, "%*soff = 10\n", INDENT * 2, "");
+    wt_format(w, "%*sself = %s.uninit()\n", INDENT * 2, "", name);
+    wt_write(w, deser.buf.data, deser.buf.len);
+    wt_format(w, "%*smagic_end = unpack('>Q', buf[off:off + 8])[0]\n", INDENT * 2, "");
+    wt_format(w, "%*sif magic_end != MSG_MAGIC_END:\n", INDENT * 2, "");
+    wt_format(w, "%*sraise ValueError\n", INDENT * 3, "");
+    wt_format(w, "%*soff += 8\n", INDENT * 2, "");
+    wt_format(w, "%*sreturn self, off\n", INDENT * 2, "");
+    wt_format(w, "\n");
+
+    buffered_writer_drop(ser);
+    buffered_writer_drop(deser);
+    free(name);
+    vec_drop(fields);
+    // NOTE(review): the Layout stored in `layouts` above was built from `type`; confirm nothing reads it after this free.
+    free(type);
+}
+
+// Write the abstract `<name>Message` base class, whose deserialize dispatches on the tag, then one class per message.
+static void define_messages(Writer *w, MessagesObject msgs, Program *p) {
+    char *prefix = strndup(msgs.name.ptr, msgs.name.len);
+    wt_format(w, "class %sMessage(ABC):\n", prefix);
+    wt_format(w, "%*s@abstractmethod\n", INDENT, "");
+    wt_format(w, "%*sdef serialize(self, buf: bytearray) -> int:\n", INDENT, "");
+    wt_format(w, "%*spass\n", INDENT * 2, "");
+    wt_format(w, "%*s@classmethod\n", INDENT, "");
+    wt_format(w, "%*sdef deserialize(cls, buf: bytes) -> Tuple['%sMessage', int]:\n", INDENT, "", prefix);
+    wt_format(w, "%*smagic_start, tag = unpack('>QH', buf[0:10])\n", INDENT * 2, "");
+    wt_format(w, "%*sif magic_start != MSG_MAGIC_START:\n", INDENT * 2, "");
+    wt_format(w, "%*sraise ValueError\n", INDENT * 3, "");
+    for (size_t i = 0; i < msgs.messages.len; i++) {
+        if (i == 0) {
+            wt_format(w, "%*sif tag == 0:\n", INDENT * 2, "");
+        } else {
+            wt_format(w, "%*selif tag == %lu:\n", INDENT * 2, "", i);
+        }
+        StringSlice name = msgs.messages.data[i].name;
+        wt_format(w, "%*sreturn %s%.*s._deserialize(buf)\n", INDENT * 3, "", prefix, name.len, name.ptr);
+    }
+    wt_format(w, "%*selse:\n", INDENT * 2, "");
+    wt_format(w, "%*sraise ValueError\n", INDENT * 3, "");
+
+    for (size_t i = 0; i < msgs.messages.len; i++) {
+        define_message(w, prefix, i, p->layouts, msgs.messages.data[i], msgs.version);
+    }
+    free(prefix);
+}
+
+// Entry point of the Python backend: write the module prelude, the struct classes, then the message classes.
+void codegen_python(Writer *source, Program *p) {
+    wt_format(
+        source,
+        "# generated file\n"
+        "from dataclasses import dataclass\n"
+        "from typing import List, Tuple\n"
+        "from abc import ABC, abstractmethod\n"
+        "from struct import pack, unpack\n"
+        "\n"
+        "MSG_MAGIC_START = 0x%016lX\n"
+        "MSG_MAGIC_END = 0x%016lX\n"
+        "\n",
+        MSG_MAGIC_START,
+        MSG_MAGIC_END
+    );
+
+    define_struct_classes(source, p);
+    for (size_t i = 0; i < p->messages.len; i++) {
+        define_messages(source, p->messages.data[i], p);
+    }
+}
diff --git a/ser/codegen_python.h b/ser/codegen_python.h
new file mode 100644
index 0000000..8067070
--- /dev/null
+++ b/ser/codegen_python.h
@@ -0,0 +1,8 @@
+#ifndef CODEGEN_PYTHON_H
+#define CODEGEN_PYTHON_H
+
+#include "codegen.h"
+
+void codegen_python(Writer *source, Program *p);
+
+#endif
diff --git a/ser/macro_utils.h b/ser/macro_utils.h
index 45c3553..9254270 100644
--- a/ser/macro_utils.h
+++ b/ser/macro_utils.h
@@ -22,6 +22,7 @@
 #define SND(a, b, ...) b
 #define FST(a, ...) a
 #define CAT(a, b) a##b
+#define STR(a) #a
 
 #define PROBE() ~, 1
 #define IS_PROBE(...) SND(__VA_ARGS__, 0)
diff --git a/ser/main.c b/ser/main.c
index 183cea7..908cb40 100644
--- a/ser/main.c
+++ b/ser/main.c
@@ -1,5 +1,6 @@
 #include "ast.h"
 #include "codegen_c.h"
+#include "codegen_python.h"
 #include "hashmap.h"
 #include "lexer.h"
 #include "log.h"
@@ -16,15 +17,31 @@ void abort_error(uint32_t error_count) {
 
 typedef enum {
     BackendC,
+    BackendPython,
 } Backend;
 
+static Hashmap *backend_map = NULL;
+
+typedef struct {
+    StringSlice name;
+    Backend b;
+} BackendString;
+
+impl_hashmap_delegate(backend, BackendString, string_slice, name);
+
 Backend parse_backend(const char *b) {
-    if (strcmp(b, "c") == 0) {
-        return BackendC;
-    } else {
-        log_error("Couldn't parse requested backend: got %s expected one of 'c'.", b);
+    if (backend_map == NULL) {
+        backend_map = hashmap_init(backend_hash, backend_equal, NULL, sizeof(BackendString));
+        hashmap_set(backend_map, &(BackendString){.name = STRING_SLICE("c"), .b = BackendC});
+        hashmap_set(backend_map, &(BackendString){.name = STRING_SLICE("python"), .b = BackendPython});
+    }
+
+    BackendString *backend = hashmap_get(backend_map, &(BackendString){.name.ptr = b, .name.len = strlen(b)});
+    if (backend == NULL) {
+        log_error("Unknown backend '%s'", b);
         exit(1);
     }
+    return backend->b;
 }
 
 int main(int argc, char **argv) {
@@ -102,10 +119,25 @@ int main(int argc, char **argv) {
         free(header_path);
         break;
     }
+    case BackendPython: {
+        FileWriter source = file_writer_init(output);
+
+        codegen_python((Writer *)&source, &evaluation_result.program);
+
+        file_writer_drop(source);
+        break;
+    }
+    default:
+        log_error("What the fuck ?");
+        exit(1);
     }
 
     program_drop(evaluation_result.program);
     ast_drop(parsing_result.ctx);
     vec_drop(lexing_result.tokens);
     source_drop(src);
+
+    if (backend_map != NULL) {
+        hashmap_drop(backend_map);
+    }
 }