diff --git a/Makefile b/Makefile index fbb3d45..7f8b4f3 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ Q=@ CC=gcc GCCCFLAGS=-Wno-format-truncation -CFLAGS=-std=c11 -pedantic -g -Wall -pthread -D_GNU_SOURCE +CFLAGS=-std=gnu11 -pedantic -g -Wall -pthread -D_GNU_SOURCE -fsanitize=address LDFLAGS=-lm # The root for /sys and /dev needs to be moved in docker, this should stay empty in most cases diff --git a/client.c b/client.c index 726c8d7..5a3d212 100644 --- a/client.c +++ b/client.c @@ -37,13 +37,13 @@ static struct pollfd *socket_poll = &poll_fds[1]; static int fifo = -1; static int sock = -1; // static to avoid having this on the stack because a message is about 2kb in memory -static Message message; +static DeviceMessage message; static Vec devices_fd; static Vec devices_info; -static ClientConfig config; -static MessageRequest device_request; +static ClientConfig config; +static DeviceRequest device_request; static void default_fifo_path(void *ptr) { *(char **)ptr = (char *)FIFO_PATH; } static void default_retry_delay(void *ptr) { *(struct timespec *)ptr = CONNECTION_RETRY_DELAY; } @@ -58,17 +58,17 @@ static void default_to_white(void *ptr) { } static const JSONPropertyAdapter ControllerStateAdapterProps[] = { - {".led_color", &StringAdapter, offsetof(MessageControllerState, led), default_to_white, tsf_hex_to_color }, - {".rumble.0", &NumberAdapter, offsetof(MessageControllerState, small_rumble), default_to_zero_u8, tsf_num_to_u8_clamp}, - {".rumble.1", &NumberAdapter, offsetof(MessageControllerState, big_rumble), default_to_zero_u8, tsf_num_to_u8_clamp}, - {".flash.0", &NumberAdapter, offsetof(MessageControllerState, flash_on), default_to_zero_u8, tsf_num_to_u8_clamp}, - {".flash.1", &NumberAdapter, offsetof(MessageControllerState, flash_off), default_to_zero_u8, tsf_num_to_u8_clamp}, - {".index", &NumberAdapter, offsetof(MessageControllerState, index), default_to_zero_u32, tsf_num_to_int } + {".led_color", &StringAdapter, offsetof(DeviceControllerState, led), default_to_white, tsf_hex_to_color }, + {".rumble.0", &NumberAdapter, offsetof(DeviceControllerState, small_rumble), default_to_zero_u8, tsf_num_to_u8_clamp}, + {".rumble.1", &NumberAdapter, offsetof(DeviceControllerState, big_rumble), default_to_zero_u8, tsf_num_to_u8_clamp}, + {".flash.0", &NumberAdapter, offsetof(DeviceControllerState, flash_on), default_to_zero_u8, tsf_num_to_u8_clamp}, + {".flash.1", &NumberAdapter, offsetof(DeviceControllerState, flash_off), default_to_zero_u8, tsf_num_to_u8_clamp}, + {".index", &NumberAdapter, offsetof(DeviceControllerState, index), default_to_zero_u32, tsf_num_to_int } }; static const JSONAdapter ControllerStateAdapter = { .props = (JSONPropertyAdapter *)ControllerStateAdapterProps, .prop_count = sizeof(ControllerStateAdapterProps) / sizeof(JSONPropertyAdapter), - .size = sizeof(MessageControllerState), + .size = sizeof(DeviceControllerState), }; static const JSONPropertyAdapter ControllerAdapterProps[] = { @@ -127,12 +127,12 @@ static void print_config() { void destroy_devices(void) { for (int i = 0; i < config.slot_count; i++) { - int fd = *(int *)vec_get(&devices_fd, i); - MessageDeviceInfo *info = vec_get(&devices_info, i); + int fd = *(int *)vec_get(&devices_fd, i); + DeviceInfo *info = vec_get(&devices_info, i); - if (info->code == DeviceInfo) { + if (info->tag == DeviceTagInfo) { ioctl(fd, UI_DEV_DESTROY); - info->code = NoMessage; + info->tag = DeviceTagNone; } } } @@ -142,8 +142,8 @@ bool device_exists(int index) { return false; } - MessageDeviceInfo *info = vec_get(&devices_info, index); - return info->code == DeviceInfo; + DeviceInfo *info = vec_get(&devices_info, index); + return info->tag == DeviceTagInfo; } void device_destroy(int slot) { @@ -153,15 +153,15 @@ void device_destroy(int slot) { int fd = *(int *)vec_get(&devices_fd, slot); - MessageDeviceInfo *info = vec_get(&devices_info, slot); + DeviceInfo *info = vec_get(&devices_info, slot); - if (info->code == DeviceInfo) { + if (info->tag == DeviceTagInfo) { ioctl(fd, UI_DEV_DESTROY); - info->code = NoMessage; + info->tag = DeviceTagNone; } } -void device_init(MessageDeviceInfo *dev) { +void device_init(DeviceInfo *dev) { if (dev->slot >= devices_info.len) { printf("CLIENT: Got wrong device index\n"); return; @@ -172,35 +172,36 @@ void device_init(MessageDeviceInfo *dev) { int fd = *(int *)vec_get(&devices_fd, dev->slot); // Abs - if (dev->abs_count > 0) { + if (dev->abs.len > 0) { ioctl(fd, UI_SET_EVBIT, EV_ABS); - for (int i = 0; i < dev->abs_count; i++) { + for (int i = 0; i < dev->abs.len; i++) { struct uinput_abs_setup setup = {0}; - setup.code = dev->abs_id[i]; - setup.absinfo.minimum = dev->abs_min[i]; - setup.absinfo.maximum = dev->abs_max[i]; - setup.absinfo.fuzz = dev->abs_fuzz[i]; - setup.absinfo.flat = dev->abs_flat[i]; - setup.absinfo.resolution = dev->abs_res[i]; + Abs abs = dev->abs.data[i]; + setup.code = abs.id; + setup.absinfo.minimum = abs.min; + setup.absinfo.maximum = abs.max; + setup.absinfo.fuzz = abs.fuzz; + setup.absinfo.flat = abs.flat; + setup.absinfo.resolution = abs.res; setup.absinfo.value = 0; ioctl(fd, UI_ABS_SETUP, &setup); } } // Rel - if (dev->rel_count > 0) { + if (dev->rel.len > 0) { ioctl(fd, UI_SET_EVBIT, EV_REL); - for (int i = 0; i < dev->rel_count; i++) { - ioctl(fd, UI_SET_RELBIT, dev->rel_id[i]); + for (int i = 0; i < dev->rel.len; i++) { + ioctl(fd, UI_SET_RELBIT, dev->rel.data[i].id); } } // Key - if (dev->key_count > 0) { + if (dev->key.len > 0) { ioctl(fd, UI_SET_EVBIT, EV_KEY); - for (int i = 0; i < dev->key_count; i++) { - ioctl(fd, UI_SET_KEYBIT, dev->key_id[i]); + for (int i = 0; i < dev->key.len; i++) { + ioctl(fd, UI_SET_KEYBIT, dev->key.data[i].id); } } @@ -217,11 +218,11 @@ void device_init(MessageDeviceInfo *dev) { ioctl(fd, UI_DEV_SETUP, &setup); ioctl(fd, UI_DEV_CREATE); - MessageDeviceInfo *dst = vec_get(&devices_info, dev->slot); + DeviceInfo *dst = vec_get(&devices_info, dev->slot); - memcpy(dst, dev, sizeof(MessageDeviceInfo)); - printf("CLIENT: Got device [%d]: '%s' (abs: %d, rel: %d, key: %d)\n", dev->slot, ctr->device_name, dev->abs_count, - dev->rel_count, dev->key_count); + memcpy(dst, dev, sizeof(DeviceInfo)); + printf("CLIENT: Got device [%d]: '%s' (abs: %d, rel: %d, key: %d)\n", dev->slot, ctr->device_name, dev->abs.len, dev->rel.len, + dev->key.len); } // Send an event to uinput, device must exist @@ -241,33 +242,34 @@ bool device_emit(int index, uint16_t type, uint16_t id, uint32_t value) { } // Update device with report -void device_handle_report(MessageDeviceReport *report) { +void device_handle_report(DeviceReport *report) { if (!device_exists(report->slot)) { printf("CLIENT: [%d] Got report before device info\n", report->slot); return; } - MessageDeviceInfo *info = vec_get(&devices_info, report->slot); + DeviceInfo *info = vec_get(&devices_info, report->slot); - if (report->abs_count != info->abs_count || report->rel_count != info->rel_count || report->key_count != info->key_count) { - printf("CLIENT: Report doesn't match with device info\n"); + if (report->abs.len != info->abs.len || report->rel.len != info->rel.len || report->key.len != info->key.len) { + printf("CLIENT: Report doesn't match with device info (expected %u/%u/%u, got %u/%u/%u)\n", info->abs.len, info->rel.len, + info->key.len, report->abs.len, report->rel.len, report->key.len); return; } - for (int i = 0; i < report->abs_count; i++) { - if (device_emit(report->slot, EV_ABS, info->abs_id[i], report->abs[i]) != 0) { + for (int i = 0; i < report->abs.len; i++) { + if (device_emit(report->slot, EV_ABS, info->abs.data[i].id, report->abs.data[i]) != 0) { printf("CLIENT: Error writing abs event to uinput\n"); } } - for (int i = 0; i < report->rel_count; i++) { - if (device_emit(report->slot, EV_REL, info->rel_id[i], report->rel[i]) != 0) { + for (int i = 0; i < report->rel.len; i++) { + if (device_emit(report->slot, EV_REL, info->rel.data[i].id, report->rel.data[i]) != 0) { printf("CLIENT: Error writing rel event to uinput\n"); } } - for (int i = 0; i < report->key_count; i++) { - if (device_emit(report->slot, EV_KEY, info->key_id[i], (uint32_t)(!report->key[i]) - 1) != 0) { + for (int i = 0; i < report->key.len; i++) { + if (device_emit(report->slot, EV_KEY, info->key.data[i].id, (uint32_t)(!report->key.data[i]) - 1) != 0) { printf("CLIENT: Error writing key event to uinput\n"); } } @@ -278,10 +280,10 @@ void device_handle_report(MessageDeviceReport *report) { void setup_devices(void) { devices_fd = vec_of(int); - devices_info = vec_of(MessageDeviceInfo); + devices_info = vec_of(DeviceInfo); - MessageDeviceInfo no_info = {0}; - no_info.code = NoMessage; + DeviceInfo no_info = {0}; + no_info.tag = DeviceTagNone; for (int i = 0; i < config.slot_count; i++) { int fd = open(FSROOT "/dev/uinput", O_WRONLY | O_NONBLOCK); @@ -350,9 +352,9 @@ void connect_server(void) { socket_poll->fd = sock; printf("CLIENT: Connected !\n"); - uint8_t buf[2048] __attribute__((aligned(4))) = {0}; + uint8_t buf[2048] __attribute__((aligned(8))) = {0}; - int len = msg_serialize(buf, 2048, (Message *)&device_request); + int len = msg_device_serialize(buf, 2048, (DeviceMessage *)&device_request); if (len > 0) { if (send(sock, buf, len, 0) > 0) { printf("CLIENT: Sent device request\n"); @@ -382,20 +384,20 @@ void setup_server(char *address, uint16_t port) { } void build_device_request(void) { - TagList *reqs = malloc(config.slot_count * sizeof(TagList)); + device_request.tag = DeviceTagRequest; + device_request.requests.len = config.slot_count; + device_request.requests.data = malloc(config.slot_count * sizeof(TagList)); for (int i = 0; i < config.slot_count; i++) { - TagList *req = &reqs[i]; - req->count = config.slots[i].controller_count; - req->tags = malloc(req->count * sizeof(char *)); + TagList *list = &device_request.requests.data[i]; + list->tags.len = config.slots[i].controller_count; + list->tags.data = malloc(list->tags.len * sizeof(typeof(*list->tags.data))); - for (int j = 0; j < req->count; j++) { - req->tags[j] = config.slots[i].controllers[j].tag; + for (int j = 0; j < list->tags.len; j++) { + char *name = config.slots[i].controllers[j].tag; + list->tags.data[j].name.len = strlen(name); + list->tags.data[j].name.data = name; } } - - device_request.code = Request; - device_request.request_count = config.slot_count; - device_request.requests = reqs; } void client_run(char *address, uint16_t port, char *config_path) { @@ -431,10 +433,10 @@ void client_run(char *address, uint16_t port, char *config_path) { setup_devices(); setup_server(address, port); - uint8_t buf[2048] __attribute__((aligned(4))); + uint8_t buf[2048] __attribute__((aligned(8))); uint8_t json_buf[2048] __attribute__((aligned(8))); - while (1) { + while (true) { int rc = poll(poll_fds, 2, -1); if (rc < 0) { perror("CLIENT: Error on poll"); @@ -451,11 +453,11 @@ void client_run(char *address, uint16_t port, char *config_path) { if (rc < 0) { printf("CLIENT: Error when parsing fifo message as json (%s at index %lu)\n", json_strerr(), json_errloc()); } else { - MessageControllerState msg; - msg.code = ControllerState; + DeviceControllerState msg; + msg.tag = DeviceTagControllerState; json_adapt(json_buf, &ControllerStateAdapter, &msg); - int len = msg_serialize(buf, 2048, (Message *)&msg); + int len = msg_device_serialize(buf, 2048, (DeviceMessage *)&msg); if (len > 0) { if (send(sock, buf, len, 0) > 0) { printf("CLIENT: Sent controller state: #%02x%02x%02x flash: (%d, %d) rumble: " @@ -479,7 +481,7 @@ void client_run(char *address, uint16_t port, char *config_path) { continue; } - int msg_len = msg_deserialize(buf, len, &message); + int msg_len = msg_device_deserialize(buf, len, &message); // We've got data from the server if (msg_len < 0) { recv(sock, buf, 2048, 0); @@ -500,15 +502,15 @@ void client_run(char *address, uint16_t port, char *config_path) { recv(sock, buf, msg_len, 0); - if (message.code == DeviceInfo) { - if (device_exists(message.device_info.slot)) { + if (message.tag == DeviceTagInfo) { + if (device_exists(message.info.slot)) { printf("CLIENT: Got more than one device info for same device\n"); } - device_init((MessageDeviceInfo *)&message); - } else if (message.code == DeviceReport) { - device_handle_report((MessageDeviceReport *)&message); - } else if (message.code == DeviceDestroy) { + device_init((DeviceInfo *)&message); + } else if (message.tag == DeviceTagReport) { + device_handle_report((DeviceReport *)&message); + } else if (message.tag == DeviceTagDestroy) { device_destroy(message.destroy.index); printf("CLIENT: Lost device %d\n", message.destroy.index); } else { diff --git a/hid.c b/hid.c index 27c0050..da1ae42 100644 --- a/hid.c +++ b/hid.c @@ -55,10 +55,10 @@ uniq_t parse_uniq(char uniq[17]) { // Finish setup of a partially initialized device (set device_info and mapping) void setup_device(PhysicalDevice *dev) { - dev->device_info.code = DeviceInfo; - dev->device_info.abs_count = 0; - dev->device_info.rel_count = 0; - dev->device_info.key_count = 0; + dev->device_info.tag = DeviceTagInfo; + dev->device_info.abs.len = 0; + dev->device_info.rel.len = 0; + dev->device_info.key.len = 0; for (int i = 0; i < ABS_CNT; i++) dev->mapping.abs_indices[i] = -1; @@ -93,26 +93,27 @@ void setup_device(PhysicalDevice *dev) { struct input_absinfo abs; ioctl(dev->event, EVIOCGABS(i), &abs); - uint16_t index = dev->device_info.abs_count++; + uint16_t index = dev->device_info.abs.len++; - dev->device_info.abs_min[index] = abs.minimum; - dev->device_info.abs_max[index] = abs.maximum; - dev->device_info.abs_fuzz[index] = abs.fuzz; - dev->device_info.abs_flat[index] = abs.flat; - dev->device_info.abs_res[index] = abs.resolution; + Abs *dev_abs = &dev->device_info.abs.data[index]; + dev_abs->min = abs.minimum; + dev_abs->max = abs.maximum; + dev_abs->fuzz = abs.fuzz; + dev_abs->flat = abs.flat; + dev_abs->res = abs.resolution; + dev_abs->id = i; // Bidirectional mapping id <-> index // We need this to avoid wasting space in packets because ids are sparse - dev->device_info.abs_id[index] = i; - dev->mapping.abs_indices[i] = index; + dev->mapping.abs_indices[i] = index; } else if (type == EV_REL) { - uint16_t index = dev->device_info.rel_count++; + uint16_t index = dev->device_info.rel.len++; - dev->device_info.rel_id[index] = i; + dev->device_info.rel.data[index].id = i; dev->mapping.rel_indices[i] = index; } else if (type == EV_KEY) { - uint16_t index = dev->device_info.key_count++; + uint16_t index = dev->device_info.key.len++; - dev->device_info.key_id[index] = i; + dev->device_info.key.data[index].id = i; dev->mapping.key_indices[i] = index; } } @@ -445,7 +446,7 @@ void poll_devices(void) { } // "Execute" a MessageControllerState: set the led color, rumble and flash using the hidraw interface (Dualshock 4 only) -void apply_controller_state(Controller *c, MessageControllerState *state) { +void apply_controller_state(Controller *c, DeviceControllerState *state) { if (c->ctr.ps4_hidraw && c->dev.hidraw < 0) { printf("HID: Trying to apply controller state on incompatible device (%lu)\n", c->dev.id); return; diff --git a/hid.h b/hid.h index daf6191..8bc308b 100644 --- a/hid.h +++ b/hid.h @@ -30,7 +30,7 @@ typedef struct { uint64_t id; char *name; DeviceMap mapping; - MessageDeviceInfo device_info; + DeviceInfo device_info; } PhysicalDevice; typedef struct { @@ -42,6 +42,6 @@ void *hid_thread(void *arg); void return_device(Controller *c); void forget_device(Controller *c); bool get_device(char **tags, size_t tag_count, bool *stop, Controller *res, uint8_t *index); -void apply_controller_state(Controller *c, MessageControllerState *state); +void apply_controller_state(Controller *c, DeviceControllerState *state); #endif diff --git a/net.c b/net.c index 00f8a6e..03fe02d 100644 --- a/net.c +++ b/net.c @@ -1,434 +1,366 @@ +// Generated file, do not edit (its not like it'll explode if you do, but its better not to) #include "net.h" - -#include "util.h" - -#include #include -#include -// Deserialize the message in buf, buf must be at least 4 aligned. Returns -1 on error, otherwise returns 0 -// and writes result to dst -int msg_deserialize(const uint8_t *buf, size_t len, Message *restrict dst) { - { - if (len <= MAGIC_SIZE) { - return -1; - } +__attribute__((unused)) static int abs_serialize(struct Abs val, byte *buf); +__attribute__((unused)) static int abs_deserialize(struct Abs *val, const byte *buf); +__attribute__((unused)) static void abs_free(struct Abs val); +__attribute__((unused)) static int key_serialize(struct Key val, byte *buf); +__attribute__((unused)) static int key_deserialize(struct Key *val, const byte *buf); +__attribute__((unused)) static void key_free(struct Key val); +__attribute__((unused)) static int rel_serialize(struct Rel val, byte *buf); +__attribute__((unused)) static int rel_deserialize(struct Rel *val, const byte *buf); +__attribute__((unused)) static void rel_free(struct Rel val); +__attribute__((unused)) static int tag_list_serialize(struct TagList val, byte *buf); +__attribute__((unused)) static int tag_list_deserialize(struct TagList *val, const byte *buf); +__attribute__((unused)) static void tag_list_free(struct TagList val); +__attribute__((unused)) static int tag_serialize(struct Tag val, byte *buf); +__attribute__((unused)) static int tag_deserialize(struct Tag *val, const byte *buf); +__attribute__((unused)) static void tag_free(struct Tag val); - if (*(MAGIC_TYPE *)buf != MAGIC_BEG) { - printf("NET: No magic in message\n"); - return -1; - } +static int abs_serialize(struct Abs val, byte *buf) { + byte * base_buf = buf; + *(uint32_t *)&buf[0] = val.min; + *(uint32_t *)&buf[4] = val.max; + *(uint32_t *)&buf[8] = val.fuzz; + *(uint32_t *)&buf[12] = val.flat; + *(uint32_t *)&buf[16] = val.res; + *(uint16_t *)&buf[20] = val.id; + buf += 24; + return (int)(buf - base_buf); +} +static int abs_deserialize(struct Abs *val, const byte *buf) { + const byte * base_buf = buf; + val->min = *(uint32_t *)&buf[0]; + val->max = *(uint32_t *)&buf[4]; + val->fuzz = *(uint32_t *)&buf[8]; + val->flat = *(uint32_t *)&buf[12]; + val->res = *(uint32_t *)&buf[16]; + val->id = *(uint16_t *)&buf[20]; + buf += 24; + return (int)(buf - base_buf); +} +static void abs_free(struct Abs val) { } - buf += MAGIC_SIZE; - len -= MAGIC_SIZE; +static int key_serialize(struct Key val, byte *buf) { + byte * base_buf = buf; + *(uint16_t *)&buf[0] = val.id; + buf += 2; + return (int)(buf - base_buf); +} +static int key_deserialize(struct Key *val, const byte *buf) { + const byte * base_buf = buf; + val->id = *(uint16_t *)&buf[0]; + buf += 2; + return (int)(buf - base_buf); +} +static void key_free(struct Key val) { } + +static int rel_serialize(struct Rel val, byte *buf) { + byte * base_buf = buf; + *(uint16_t *)&buf[0] = val.id; + buf += 2; + return (int)(buf - base_buf); +} +static int rel_deserialize(struct Rel *val, const byte *buf) { + const byte * base_buf = buf; + val->id = *(uint16_t *)&buf[0]; + buf += 2; + return (int)(buf - base_buf); +} +static void rel_free(struct Rel val) { } + +static int tag_list_serialize(struct TagList val, byte *buf) { + byte * base_buf = buf; + *(uint16_t *)&buf[0] = val.tags.len; + buf += 2; + for(size_t i = 0; i < val.tags.len; i++) { + typeof(val.tags.data[i]) e0 = val.tags.data[i]; + buf += tag_serialize(e0, &buf[0]); } - // Decrement len so that it becomes the len of the data without the code. - if (len-- < 1) - return -1; - // This ensures that only a byte is read instead of a full enum value - uint8_t code_byte = buf[0]; - MessageCode code = (MessageCode)code_byte; - uint32_t size = 0; - - uint16_t abs, rel, key, *buf16; - uint8_t index, slot; - - switch (code) { - case DeviceInfo: - if (len < 7) - return -1; - slot = buf[2]; - index = buf[3]; - // buf + 4: a byte for, code, padding, slot, index - buf16 = (uint16_t *)(buf + 4); - abs = buf16[0]; - rel = buf16[1]; - key = buf16[2]; - buf += 12; - if (MSS_DEVICE_INFO(abs, rel, key) > len) - return -1; - - dst->device_info.code = code; - dst->device_info.slot = slot; - dst->device_info.index = index; - dst->device_info.abs_count = abs; - dst->device_info.rel_count = rel; - dst->device_info.key_count = key; - - // SOA in c but serialized as AOS - for (int i = 0; i < abs; i++) { - // buf + 4: 2 bytes for id and 2 bytes of padding - uint32_t *buf32 = (uint32_t *)(buf + 4); - - dst->device_info.abs_id[i] = *(uint16_t *)buf; - dst->device_info.abs_min[i] = buf32[0]; - dst->device_info.abs_max[i] = buf32[1]; - dst->device_info.abs_fuzz[i] = buf32[2]; - dst->device_info.abs_flat[i] = buf32[3]; - dst->device_info.abs_res[i] = buf32[4]; - - buf += 24; - } - - for (int i = 0; i < rel; i++) { - dst->device_info.rel_id[i] = *(uint16_t *)buf; - buf += 2; - } - - for (int i = 0; i < key; i++) { - dst->device_info.key_id[i] = *(uint16_t *)buf; - buf += 2; - } - - size = MSS_DEVICE_INFO(abs, rel, key) + 1; - break; - case DeviceReport: - if (len < 7) - return -1; - - slot = buf[2]; - index = buf[3]; - // buf + 4: a byte for, code, padding, slot and index - buf16 = (uint16_t *)(buf + 4); - abs = buf16[0]; - rel = buf16[1]; - key = buf16[2]; - buf += 12; - if (len < MSS_DEVICE_REPORT(abs, rel, key)) - return -1; - - dst->device_report.code = code; - dst->device_report.slot = slot; - dst->device_report.index = index; - dst->device_report.abs_count = abs; - dst->device_report.rel_count = rel; - dst->device_report.key_count = key; - - for (int i = 0; i < abs; i++) { - dst->device_report.abs[i] = *(uint32_t *)buf; - buf += 4; - } - - for (int i = 0; i < rel; i++) { - dst->device_report.rel[i] = *(uint32_t *)buf; - buf += 4; - } - - for (int i = 0; i < key; i++) - dst->device_report.key[i] = *(buf++); - - buf += align_4(key) - key; - - size = MSS_DEVICE_REPORT(abs, rel, key) + 1; - break; - case ControllerState: - if (len < MSS_CONTROLLER_STATE) - return -1; - - dst->code = code; - dst->controller_state.index = *(uint16_t *)(buf + 2); - dst->controller_state.led[0] = buf[4]; - dst->controller_state.led[1] = buf[5]; - dst->controller_state.led[2] = buf[6]; - dst->controller_state.small_rumble = buf[7]; - dst->controller_state.big_rumble = buf[8]; - dst->controller_state.flash_on = buf[9]; - dst->controller_state.flash_off = buf[10]; - size = MSS_CONTROLLER_STATE + 1; - buf += size; - break; - case Request: { - if (len < 3) - return -1; - - dst->code = code; - dst->request.request_count = *(uint16_t *)(buf + 2); - buf += 4; // 1 bytes for code, 1 byte for padding and 2 bytes for count - - int count = dst->request.request_count; - TagList *reqs = malloc(count * sizeof(TagList)); - // The length of the message, will be updated as we read more. - int expected_len = 3; - - for (int i = 0; i < dst->request.request_count; i++) { - expected_len += 2; - if (len < expected_len) { - return -1; - } - - TagList *tags = &reqs[i]; - tags->count = *(uint16_t *)buf; - tags->tags = malloc(tags->count * sizeof(char *)); - buf += 2; - - for (int j = 0; j < tags->count; j++) { - expected_len += 2; - if (len < expected_len) { - return -1; - } - - uint16_t str_len = *(uint16_t *)buf; - buf += 2; - - expected_len += align_2(str_len); - if (len < expected_len) { - return -1; - } - - char *str = malloc(str_len + 1); - str[str_len] = '\0'; - - strncpy(str, (char *)buf, str_len); - - tags->tags[j] = str; - - buf += align_2(str_len); - } - } - - dst->request.requests = reqs; - size = expected_len + 1; - break; + buf = (byte*)(((((uintptr_t)buf - 1) >> 1) + 1) << 1); + return (int)(buf - base_buf); +} +static int tag_list_deserialize(struct TagList *val, const byte *buf) { + const byte * base_buf = buf; + val->tags.len = *(uint16_t *)&buf[0]; + buf += 2; + val->tags.data = malloc(val->tags.len * sizeof(typeof(*val->tags.data))); + for(size_t i = 0; i < val->tags.len; i++) { + typeof(&val->tags.data[i]) e0 = &val->tags.data[i]; + buf += tag_deserialize(e0, &buf[0]); } - case DeviceDestroy: - if (len < MSS_DESTROY) - return -1; - - dst->code = code; - dst->destroy.index = *(uint16_t *)(buf + 2); - size = MSS_DESTROY + 1; - buf += size; - break; - default: - return -1; + buf = (byte*)(((((uintptr_t)buf - 1) >> 1) + 1) << 1); + return (int)(buf - base_buf); +} +static void tag_list_free(struct TagList val) { + for(size_t i = 0; i < val.tags.len; i++) { + typeof(val.tags.data[i]) e0 = val.tags.data[i]; + tag_free(e0); } - - if (align_m(size) + MAGIC_SIZE > len + 1) { - return -1; - } - - // WARN: This is technically bad, but should be ok nonetheless - MAGIC_TYPE *mbuf = (MAGIC_TYPE *)align_m((uintptr_t)buf); - - if (*mbuf != MAGIC_END) { - printf("NET: Magic not found\n"); - return -1; - } - - return align_m(size) + 2 * MAGIC_SIZE; + free(val.tags.data); } -// Serialize the message msg in buf, buf must be at least 4 aligned. Returns -1 on error (buf not big enough); -int msg_serialize(uint8_t *restrict buf, size_t len, const Message *msg) { - // If len is less than the two magic and the code we can't serialize any message - if (len < MAGIC_SIZE * 2 + 1) +static int tag_serialize(struct Tag val, byte *buf) { + byte * base_buf = buf; + *(uint16_t *)&buf[0] = val.name.len; + buf += 2; + for(size_t i = 0; i < val.name.len; i++) { + typeof(val.name.data[i]) e0 = val.name.data[i]; + *(char *)&buf[0] = e0; + buf += 1; + } + buf = (byte*)(((((uintptr_t)buf - 1) >> 1) + 1) << 1); + return (int)(buf - base_buf); +} +static int tag_deserialize(struct Tag *val, const byte *buf) { + const byte * base_buf = buf; + val->name.len = *(uint16_t *)&buf[0]; + buf += 2; + val->name.data = malloc(val->name.len * sizeof(typeof(*val->name.data))); + for(size_t i = 0; i < val->name.len; i++) { + typeof(&val->name.data[i]) e0 = &val->name.data[i]; + *e0 = *(char *)&buf[0]; + buf += 1; + } + buf = (byte*)(((((uintptr_t)buf - 1) >> 1) + 1) << 1); + return (int)(buf - base_buf); +} +static void tag_free(struct Tag val) { + free(val.name.data); +} + +int msg_device_serialize(byte *buf, size_t len, DeviceMessage *msg) { + const byte *base_buf = buf; + if(len < 2 * MSG_MAGIC_SIZE) return -1; - - *(MAGIC_TYPE *)buf = MAGIC_BEG; - buf += MAGIC_SIZE; - len -= MAGIC_SIZE + 1; - - uint16_t abs, rel, key, *buf16; - uint32_t size; - - switch (msg->code) { - case DeviceInfo: - abs = msg->device_info.abs_count; - rel = msg->device_info.rel_count; - key = msg->device_info.key_count; - if (len < MSS_DEVICE_INFO(abs, rel, key)) - return -1; - - buf[0] = (uint8_t)msg->code; - // 1 byte of padding - buf[2] = (uint8_t)msg->device_info.slot; - buf[3] = (uint8_t)msg->device_info.index; - // buf + 4: a byte for, code, padding, slot, index - buf16 = (uint16_t *)(buf + 4); - buf16[0] = abs; - buf16[1] = rel; - buf16[2] = key; - buf += 12; - - // Back to 4 aligned - for (int i = 0; i < abs; i++) { - // buf + 4: 2 bytes for id and 2 bytes of padding - uint32_t *buf32 = (uint32_t *)(buf + 4); - - *(uint16_t *)buf = msg->device_info.abs_id[i]; - - buf32[0] = msg->device_info.abs_min[i]; - buf32[1] = msg->device_info.abs_max[i]; - buf32[2] = msg->device_info.abs_fuzz[i]; - buf32[3] = msg->device_info.abs_flat[i]; - buf32[4] = msg->device_info.abs_res[i]; - - buf += 24; - } - // Still 4 aligned - for (int i = 0; i < rel; i++) { - *(uint16_t *)buf = msg->device_info.rel_id[i]; - buf += 2; - } - - for (int i = 0; i < key; i++) { - *(uint16_t *)buf = msg->device_info.key_id[i]; - buf += 2; - } - - size = MSS_DEVICE_INFO(abs, rel, key) + 1; + *(MsgMagic*)buf = MSG_MAGIC_START; + buf += MSG_MAGIC_SIZE; + switch(msg->tag) { + case DeviceTagNone: break; - case DeviceReport: - abs = msg->device_report.abs_count; - rel = msg->device_report.rel_count; - key = msg->device_report.key_count; - if (len < MSS_DEVICE_REPORT(abs, rel, key)) - return -1; - - buf[0] = (uint8_t)msg->code; - // 1 byte of padding - buf[2] = msg->device_report.slot; - buf[3] = msg->device_report.index; - // buf + 4: a byte for, code, padding, slot and index - buf16 = (uint16_t *)(buf + 4); - buf16[0] = abs; - buf16[1] = rel; - buf16[2] = key; - buf += 12; - // We're 4 aligned already - for (int i = 0; i < abs; i++) { - *(uint32_t *)buf = msg->device_report.abs[i]; + case DeviceTagInfo: { + *(uint16_t *)buf = DeviceTagInfo; + *(uint16_t *)&buf[2] = msg->info.key.len; + *(uint8_t *)&buf[4] = msg->info.slot; + *(uint8_t *)&buf[5] = msg->info.index; + *(uint8_t *)&buf[6] = msg->info.abs.len; + *(uint8_t *)&buf[7] = msg->info.rel.len; + buf += 8; + for(size_t i = 0; i < msg->info.abs.len; i++) { + typeof(msg->info.abs.data[i]) e0 = msg->info.abs.data[i]; + buf += abs_serialize(e0, &buf[0]); + } + for(size_t i = 0; i < msg->info.rel.len; i++) { + typeof(msg->info.rel.data[i]) e0 = msg->info.rel.data[i]; + buf += rel_serialize(e0, &buf[0]); + } + for(size_t i = 0; i < msg->info.key.len; i++) { + typeof(msg->info.key.data[i]) e0 = msg->info.key.data[i]; + buf += key_serialize(e0, &buf[0]); + } + buf = (byte*)(((((uintptr_t)buf - 1) >> 3) + 1) << 3); + break; + } + case DeviceTagReport: { + *(uint16_t *)buf = DeviceTagReport; + *(uint16_t *)&buf[2] = msg->report.key.len; + *(uint8_t *)&buf[4] = msg->report.slot; + *(uint8_t *)&buf[5] = msg->report.index; + *(uint8_t *)&buf[6] = msg->report.abs.len; + *(uint8_t *)&buf[7] = msg->report.rel.len; + buf += 8; + for(size_t i = 0; i < msg->report.abs.len; i++) { + typeof(msg->report.abs.data[i]) e0 = msg->report.abs.data[i]; + *(uint32_t *)&buf[0] = e0; buf += 4; } - // Still 4 aligned - for (int i = 0; i < rel; i++) { - *(uint32_t *)buf = msg->device_report.rel[i]; + for(size_t i = 0; i < msg->report.rel.len; i++) { + typeof(msg->report.rel.data[i]) e0 = msg->report.rel.data[i]; + *(uint32_t *)&buf[0] = e0; buf += 4; } - // Doesn't matter since we're writing individual bytes - for (int i = 0; i < key; i++) - *(buf++) = msg->device_report.key[i]; - - size = MSS_DEVICE_REPORT(abs, rel, key) + 1; - buf += align_4(key) - key; - break; - case ControllerState: - if (len < MSS_CONTROLLER_STATE) - return -1; - - buf[0] = (uint8_t)msg->code; - - *(uint16_t *)(buf + 2) = msg->controller_state.index; - - buf[4] = msg->controller_state.led[0]; - buf[5] = msg->controller_state.led[1]; - buf[6] = msg->controller_state.led[2]; - buf[7] = msg->controller_state.small_rumble; - buf[8] = msg->controller_state.big_rumble; - buf[9] = msg->controller_state.flash_on; - buf[10] = msg->controller_state.flash_off; - size = MSS_CONTROLLER_STATE + 1; - buf += size; - break; - case Request: { - int expected_len = MSS_REQUEST(msg->request.request_count); - if (len < expected_len) - return -1; - - buf[0] = (uint8_t)msg->code; - buf += 2; - *(uint16_t *)buf = msg->request.request_count; - buf += 2; - - for (int i = 0; i < msg->request.request_count; i++) { - - uint16_t tag_count = msg->request.requests[i].count; - char **tags = msg->request.requests[i].tags; - - *(uint16_t *)buf = tag_count; - - buf += 2; - - for (int j = 0; j < tag_count; j++) { - int str_len = strlen(tags[j]); - int byte_len = align_2(str_len); - - expected_len += 2 + byte_len; - if (len < expected_len) { - return -1; - } - - *(uint16_t *)buf = str_len; - buf += 2; - - strncpy((char *)buf, tags[j], str_len); - buf += byte_len; - } + for(size_t i = 0; i < msg->report.key.len; i++) { + typeof(msg->report.key.data[i]) e0 = msg->report.key.data[i]; + *(uint8_t *)&buf[0] = e0; + buf += 1; } - - size = expected_len + 1; + buf = (byte*)(((((uintptr_t)buf - 1) >> 3) + 1) << 3); break; } - case DeviceDestroy: - if (len < MSS_DESTROY) - return -1; - - buf[0] = (uint8_t)msg->code; - - *(uint16_t *)(buf + 2) = msg->controller_state.index; - size = MSS_DESTROY + 1; - buf += size; + case DeviceTagControllerState: { + *(uint16_t *)buf = DeviceTagControllerState; + *(uint16_t *)&buf[2] = msg->controller_state.index; + *(uint8_t *)&buf[4] = msg->controller_state.led[0]; + *(uint8_t *)&buf[5] = msg->controller_state.led[1]; + *(uint8_t *)&buf[6] = msg->controller_state.led[2]; + *(uint8_t *)&buf[7] = msg->controller_state.small_rumble; + *(uint8_t *)&buf[8] = msg->controller_state.big_rumble; + *(uint8_t *)&buf[9] = msg->controller_state.flash_on; + *(uint8_t *)&buf[10] = msg->controller_state.flash_off; + buf += 16; break; - default: - printf("ERR(msg_serialize): Trying to serialize unknown message of code %d\n", msg->code); + } + case DeviceTagRequest: { + *(uint16_t *)buf = DeviceTagRequest; + msg->request._version = 1UL; + *(uint64_t *)&buf[8] = msg->request._version; + *(uint16_t *)&buf[16] = msg->request.requests.len; + buf += 18; + for(size_t i = 0; i < msg->request.requests.len; i++) { + typeof(msg->request.requests.data[i]) e0 = msg->request.requests.data[i]; + buf += tag_list_serialize(e0, &buf[0]); + } + buf = (byte*)(((((uintptr_t)buf - 1) >> 3) + 1) << 3); + break; + } + case DeviceTagDestroy: { + *(uint16_t *)buf = DeviceTagDestroy; + *(uint16_t *)&buf[2] = msg->destroy.index; + buf += 8; + break; + } + } + *(MsgMagic*)buf = MSG_MAGIC_END; + buf += MSG_MAGIC_SIZE; + if(buf > base_buf + len) + return -1; + return (int)(buf - base_buf); +} + +int msg_device_deserialize(const byte *buf, size_t len, DeviceMessage *msg) { + const byte *base_buf = buf; + if(len < 2 * MSG_MAGIC_SIZE) + return -1; + if(*(MsgMagic*)buf != MSG_MAGIC_START) + return -1; + buf += MSG_MAGIC_SIZE; + DeviceTag tag = *(uint16_t*)buf; + switch(tag) { + case DeviceTagNone: + break; + case DeviceTagInfo: { + msg->tag = DeviceTagInfo; + msg->info.key.len = *(uint16_t *)&buf[2]; + msg->info.slot = *(uint8_t *)&buf[4]; + msg->info.index = *(uint8_t *)&buf[5]; + msg->info.abs.len = *(uint8_t *)&buf[6]; + msg->info.rel.len = *(uint8_t *)&buf[7]; + buf += 8; + for(size_t i = 0; i < msg->info.abs.len; i++) { + typeof(&msg->info.abs.data[i]) e0 = &msg->info.abs.data[i]; + buf += abs_deserialize(e0, &buf[0]); + } + for(size_t i = 0; i < msg->info.rel.len; i++) { + typeof(&msg->info.rel.data[i]) e0 = &msg->info.rel.data[i]; + buf += rel_deserialize(e0, &buf[0]); + } + for(size_t i = 0; i < msg->info.key.len; i++) { + typeof(&msg->info.key.data[i]) e0 = &msg->info.key.data[i]; + buf += key_deserialize(e0, &buf[0]); + } + buf = (byte*)(((((uintptr_t)buf - 1) >> 3) + 1) << 3); + break; + } + case DeviceTagReport: { + msg->tag = DeviceTagReport; + msg->report.key.len = *(uint16_t *)&buf[2]; + msg->report.slot = *(uint8_t *)&buf[4]; + msg->report.index = *(uint8_t *)&buf[5]; + msg->report.abs.len = *(uint8_t *)&buf[6]; + msg->report.rel.len = *(uint8_t *)&buf[7]; + buf += 8; + for(size_t i = 0; i < msg->report.abs.len; i++) { + typeof(&msg->report.abs.data[i]) e0 = &msg->report.abs.data[i]; + *e0 = *(uint32_t *)&buf[0]; + buf += 4; + } + for(size_t i = 0; i < msg->report.rel.len; i++) { + typeof(&msg->report.rel.data[i]) e0 = &msg->report.rel.data[i]; + *e0 = *(uint32_t *)&buf[0]; + buf += 4; + } + for(size_t i = 0; i < msg->report.key.len; i++) { + typeof(&msg->report.key.data[i]) e0 = &msg->report.key.data[i]; + *e0 = *(uint8_t *)&buf[0]; + buf += 1; + } + buf = (byte*)(((((uintptr_t)buf - 1) >> 3) + 1) << 3); + break; + } + case DeviceTagControllerState: { + msg->tag = DeviceTagControllerState; + msg->controller_state.index = *(uint16_t *)&buf[2]; + msg->controller_state.led[0] = *(uint8_t *)&buf[4]; + msg->controller_state.led[1] = *(uint8_t *)&buf[5]; + msg->controller_state.led[2] = *(uint8_t *)&buf[6]; + msg->controller_state.small_rumble = *(uint8_t *)&buf[7]; + msg->controller_state.big_rumble = *(uint8_t *)&buf[8]; + msg->controller_state.flash_on = *(uint8_t *)&buf[9]; + msg->controller_state.flash_off = *(uint8_t *)&buf[10]; + buf += 16; + break; + } + case DeviceTagRequest: { + msg->tag = DeviceTagRequest; + msg->request._version = *(uint64_t *)&buf[8]; + msg->request.requests.len = *(uint16_t *)&buf[16]; + buf += 18; + msg->request.requests.data = malloc(msg->request.requests.len * sizeof(typeof(*msg->request.requests.data))); + for(size_t i = 0; i < msg->request.requests.len; i++) { + typeof(&msg->request.requests.data[i]) e0 = &msg->request.requests.data[i]; + buf += tag_list_deserialize(e0, &buf[0]); + } + buf = (byte*)(((((uintptr_t)buf - 1) >> 3) + 1) << 3); + if(msg->request._version != 1UL) { + printf("Mismatched version: peers aren't the same version, expected 1 got %lu.\n", msg->request._version); + msg_device_free(msg); + return -1; + } + break; + } + case DeviceTagDestroy: { + msg->tag = DeviceTagDestroy; + msg->destroy.index = *(uint16_t *)&buf[2]; + buf += 8; + break; + } + } + if(*(MsgMagic*)buf != MSG_MAGIC_END) { + msg_device_free(msg); return -1; } - - if (align_m(size) + MAGIC_SIZE > len) { + buf += MSG_MAGIC_SIZE; + if(buf > base_buf + len) { + msg_device_free(msg); return -1; } - - MAGIC_TYPE *mbuf = (MAGIC_TYPE *)align_m((uintptr_t)buf); - - *mbuf = MAGIC_END; - - return align_m(size) + MAGIC_SIZE * 2; + return (int)(buf - base_buf); } -void msg_free(Message *msg) { - if (msg->code == Request) { - for (int i = 0; i < msg->request.request_count; i++) { - for (int j = 0; j < msg->request.requests[i].count; j++) { - free(msg->request.requests[i].tags[j]); - } - free(msg->request.requests[i].tags); - } - free(msg->request.requests); - } -} - -void print_message_buffer(const uint8_t *buf, int len) { - bool last_beg = false; - for (int i = 0; i < len; i++) { - if (i + MAGIC_SIZE <= len) { - MAGIC_TYPE magic = *(MAGIC_TYPE *)(&buf[i]); - if (magic == MAGIC_BEG) { - printf(" \033[32m%08X\033[0m", magic); - i += MAGIC_SIZE - 1; - last_beg = true; - continue; - } else if (magic == MAGIC_END) { - printf(" \033[32m%08X\033[0m", magic); - i += MAGIC_SIZE - 1; - continue; - } - } - - if (last_beg) { - last_beg = false; - printf(" \033[034m%02X\033[0m", buf[i]); - } else { - printf(" %02X", buf[i]); +void msg_device_free(DeviceMessage *msg) { + switch(msg->tag) { + case DeviceTagNone: + break; + case DeviceTagInfo: { + break; + } + case DeviceTagReport: { + break; + } + case DeviceTagControllerState: { + break; + } + case DeviceTagRequest: { + for(size_t i = 0; i < msg->request.requests.len; i++) { + typeof(msg->request.requests.data[i]) e0 = msg->request.requests.data[i]; + tag_list_free(e0); } + free(msg->request.requests.data); + break; + } + case DeviceTagDestroy: { + break; + } } } diff --git a/net.h b/net.h index 1ae50ce..493a07d 100644 --- a/net.h +++ b/net.h @@ -1,122 +1,133 @@ -// vi:ft=c -#ifndef NET_H_ -#define NET_H_ -#include "util.h" - -#include +// Generated file, do not edit (its not like it'll explode if you do, but its better not to) +#ifndef NET_H +#define NET_H #include #include +#include -#define MAGIC_TYPE uint32_t -#define MAGIC_SIZE sizeof(MAGIC_TYPE) -static const MAGIC_TYPE MAGIC_BEG = 0xDEADCAFE; -static const MAGIC_TYPE MAGIC_END = 0xCAFEDEAD; -// Align n to the next MAGIC boundary -static inline size_t align_m(uintptr_t n) { return (((n - 1) >> 2) + 1) << 2; } +typedef unsigned char byte; +typedef uint64_t MsgMagic; -typedef enum { - NoMessage = 0, - DeviceInfo = 1, - DeviceReport = 2, - DeviceDestroy = 3, - ControllerState = 4, - Request = 5, -} MessageCode; +#define MSG_MAGIC_SIZE sizeof(MsgMagic) +static const MsgMagic MSG_MAGIC_START = 0xCAFEF00DBEEFDEAD; +static const MsgMagic MSG_MAGIC_END = 0xF00DBEEFCAFEDEAD; -// Alignment 4 -typedef struct { - MessageCode code; - // + 1 byte of padding +typedef struct Abs { + uint16_t id; + uint32_t min; + uint32_t max; + uint32_t fuzz; + uint32_t flat; + uint32_t res; +} Abs; - uint8_t slot; - uint8_t index; +typedef struct Key { + uint16_t id; +} Key; - uint16_t abs_count; - uint16_t rel_count; - uint16_t key_count; +typedef struct Rel { + uint16_t id; +} Rel; - uint16_t abs_id[ABS_CNT]; - // + 2 bytes of padding per abs - uint32_t abs_min[ABS_CNT]; - uint32_t abs_max[ABS_CNT]; - uint32_t abs_fuzz[ABS_CNT]; - uint32_t abs_flat[ABS_CNT]; - uint32_t abs_res[ABS_CNT]; - - uint16_t rel_id[REL_CNT]; - - uint16_t key_id[KEY_CNT]; -} MessageDeviceInfo; -#define MSS_DEVICE_INFO(abs, rel, key) (10 + abs * 24 + rel * 2 + key * 2 + 1) -// MSS -> Message Serialized Size: -// Size of the data of the message when serialized (no alignment / padding) - -// 4 aligned -typedef struct { - MessageCode code; - // + 1 byte of padding - - uint8_t slot; - uint8_t index; - - uint16_t abs_count; - uint16_t rel_count; - uint16_t key_count; - - uint32_t abs[ABS_CNT]; - uint32_t rel[REL_CNT]; - uint8_t key[KEY_CNT]; -} MessageDeviceReport; -#define MSS_DEVICE_REPORT(abs, rel, key) (11 + abs * 4 + rel * 4 + align_4(key)) - -// 1 aligned -typedef struct { - MessageCode code; - // + 1 byte of padding - - uint16_t index; - uint8_t led[3]; - uint8_t small_rumble; - uint8_t big_rumble; - uint8_t flash_on; - uint8_t flash_off; -} MessageControllerState; -#define MSS_CONTROLLER_STATE 10 - -typedef struct { - char **tags; - uint16_t count; +typedef struct TagList { + struct { + uint16_t len; + struct Tag *data; + } tags; } TagList; -typedef struct { - MessageCode code; - // + 1 byte of padding +typedef struct Tag { + struct { + uint16_t len; + char *data; + } name; +} Tag; - TagList *requests; - uint16_t request_count; -} MessageRequest; -#define MSS_REQUEST(count) (2 + 2 * count + 1) +// Device -typedef struct { - MessageCode code; - // + 1 byte of padding +typedef enum DeviceTag { + DeviceTagNone = 0, + DeviceTagInfo = 1, + DeviceTagReport = 2, + DeviceTagControllerState = 3, + DeviceTagRequest = 4, + DeviceTagDestroy = 5, +} DeviceTag; +typedef struct DeviceInfo { + DeviceTag tag; + uint8_t slot; + uint8_t index; + struct { + uint8_t len; + struct Abs data[64]; + } abs; + struct { + uint8_t len; + struct Rel data[16]; + } rel; + struct { + uint16_t len; + struct Key data[768]; + } key; +} DeviceInfo; + +typedef struct DeviceReport { + DeviceTag tag; + uint8_t slot; + uint8_t index; + struct { + uint8_t len; + uint32_t data[64]; + } abs; + struct { + uint8_t len; + uint32_t data[16]; + } rel; + struct { + uint16_t len; + uint8_t data[768]; + } key; +} DeviceReport; + +typedef struct DeviceControllerState { + DeviceTag tag; uint16_t index; -} MessageDestroy; -#define MSS_DESTROY 3 + uint8_t led[3]; + uint8_t small_rumble; + uint8_t big_rumble; + uint8_t flash_on; + uint8_t flash_off; +} DeviceControllerState; -typedef union { - MessageCode code; - MessageRequest request; - MessageDestroy destroy; - MessageDeviceInfo device_info; - MessageDeviceReport device_report; - MessageControllerState controller_state; -} Message; +typedef struct DeviceRequest { + DeviceTag tag; + struct { + uint16_t len; + struct TagList *data; + } requests; + uint64_t _version; +} DeviceRequest; -int msg_deserialize(const uint8_t *buf, size_t len, Message *restrict dst); -int msg_serialize(uint8_t *restrict buf, size_t len, const Message *msg); -void msg_free(Message *msg); -void print_message_buffer(const uint8_t *buf, int len); +typedef struct DeviceDestroy { + DeviceTag tag; + uint16_t index; +} DeviceDestroy; +typedef union DeviceMessage { + DeviceTag tag; + DeviceInfo info; + DeviceReport report; + DeviceControllerState controller_state; + DeviceRequest request; + DeviceDestroy destroy; +} DeviceMessage; + +// Serialize the message msg to buffer dst of size len, returns the length of the serialized message, or -1 on error (buffer overflow) +int msg_device_serialize(byte *dst, size_t len, DeviceMessage *msg); +// Deserialize the message in the buffer src of size len into dst, return the length of the serialized message or -1 on error. +int msg_device_deserialize(const byte *src, size_t len, DeviceMessage *dst); + +// Free the message (created by msg_device_deserialize) +void msg_device_free(DeviceMessage *msg); #endif diff --git a/net.ser b/net.ser new file mode 100644 index 0000000..6ce0422 --- /dev/null +++ b/net.ser @@ -0,0 +1,63 @@ +struct Abs { + id: u16, + min: u32, + max: u32, + fuzz: u32, + flat: u32, + res: u32, +} + +struct Rel { + id: u16, +} + +struct Key { + id: u16, +} + +const ABS_CNT = 64; +const REL_CNT = 16; +const KEY_CNT = 768; + +struct Tag { + name: char[], +} + +struct TagList { + tags: Tag[], +} + +version(1); +messages Device { + Info { + slot: u8, + index: u8, + + abs: Abs[^ABS_CNT], + rel: Rel[^REL_CNT], + key: Key[^KEY_CNT], + } + Report { + slot: u8, + index: u8, + + abs: u32[^ABS_CNT], + rel: u32[^REL_CNT], + key: u8[^KEY_CNT], + } + ControllerState { + index: u16, + led: u8[3], + small_rumble: u8, + big_rumble: u8, + flash_on: u8, + flash_off: u8, + } + #[versioned] + Request { + requests: TagList[], + } + Destroy { + index: u16, + } +} diff --git a/ser/.ccls b/ser/.ccls new file mode 100644 index 0000000..e671fa2 --- /dev/null +++ b/ser/.ccls @@ -0,0 +1 @@ +clang diff --git a/ser/.clang-format b/ser/.clang-format new file mode 100644 index 0000000..6355949 --- /dev/null +++ b/ser/.clang-format @@ -0,0 +1,12 @@ +# vi:ft=yaml +BasedOnStyle: LLVM +IndentWidth: 4 +AlignArrayOfStructures: Left +PointerAlignment: Right +ColumnLimit: 130 +IncludeBlocks: Regroup +BinPackArguments: false +BinPackParameters: false +AlignAfterOpenBracket: BlockIndent +AllowAllArgumentsOnNextLine: false +AlignEscapedNewlines: DontAlign diff --git a/ser/.gitignore b/ser/.gitignore new file mode 100644 index 0000000..af3a786 --- /dev/null +++ b/ser/.gitignore @@ -0,0 +1,3 @@ +.ccls-cache +objects +ser diff --git a/ser/Makefile b/ser/Makefile new file mode 100644 index 0000000..85adcb8 --- /dev/null +++ b/ser/Makefile @@ -0,0 +1,31 @@ +CC=gcc +CFLAGS=-std=c2x -pedantic -g -Wall -fsanitize=address +LDFLAGS=-lm + +BUILD_DIR=./objects +BIN=./ser +SOURCES=$(wildcard *.c) + +OBJECTS:=$(patsubst %.c,$(BUILD_DIR)/%.o,$(SOURCES)) +DEPS:=$(patsubst %.c,$(BUILD_DIR)/%.d,$(SOURCES)) + +.PHONY: run build clean + +run: $(BIN) + @echo "[exec] $<" + $(BIN) +build: $(BIN) + +-include $(DEPS) + +$(BIN): $(OBJECTS) + @echo "[ld] $@" + $(CC) $(CFLAGS) $^ $(LDFLAGS) -o $@ +$(BUILD_DIR)/%.o: %.c | $(BUILD_DIR) + @echo "[cc] $<" + $(CC) -MMD $(CFLAGS) -c $< -o $@ +$(BUILD_DIR): + mkdir -p $(BUILD_DIR) +clean: + rm -rf $(BUILD_DIR) + rm -f $(BIN) diff --git a/ser/arena_allocator.c b/ser/arena_allocator.c new file mode 100644 index 0000000..53d66b9 --- /dev/null +++ b/ser/arena_allocator.c @@ -0,0 +1,39 @@ +#include "arena_allocator.h" + +#include "assert.h" +#include "vector.h" + +#include +#include + +static ArenaBlock arena_block_alloc(size_t size) { + size = size < ARENA_BLOCK_SIZE ? ARENA_BLOCK_SIZE : size; + byte *ptr = malloc(size); + assert_alloc(ptr); + return (ArenaBlock){.data = ptr, .size = size, .end = ptr + size}; +} + +void arena_block_drop(ArenaBlock block) { free(block.data); } + +ArenaAllocator arena_init() { + ArenaBlock block = arena_block_alloc(ARENA_BLOCK_SIZE); + ArenaBlockVec blocks = vec_init(); + vec_grow(&blocks, 256); + vec_push(&blocks, block); + ArenaBlock *last = blocks.data; + return (ArenaAllocator){.blocks = blocks, .ptr = last->data, .last = last}; +} +void *arena_alloc(ArenaAllocator *alloc, size_t size) { + if (alloc->ptr + size > alloc->last->end) { + ArenaBlock block = arena_block_alloc(size); + vec_push(&alloc->blocks, block); + ArenaBlock *last = &alloc->blocks.data[alloc->blocks.len - 1]; + alloc->ptr = last->data; + alloc->last = last; + } + + byte *ptr = alloc->ptr; + alloc->ptr += size; + return ptr; +} +void arena_drop(ArenaAllocator arena) { vec_drop(arena.blocks); } diff --git a/ser/arena_allocator.h b/ser/arena_allocator.h new file mode 100644 index 0000000..bd4a67a --- /dev/null +++ b/ser/arena_allocator.h @@ -0,0 +1,35 @@ +#ifndef ARENA_ALLOCATOR_H +#define ARENA_ALLOCATOR_H +#include "utils.h" +#include "vector_impl.h" + +#include +#include + +#define ARENA_BLOCK_SIZE 4096 + +typedef struct { + size_t size; + byte *data; + byte *end; +} ArenaBlock; + +void arena_block_drop(ArenaBlock block); + +VECTOR_IMPL(ArenaBlock, ArenaBlockVec, arena_block, arena_block_drop); + +// Simple growing arena allocator +typedef struct { + ArenaBlockVec blocks; + ArenaBlock *last; + byte *ptr; +} ArenaAllocator; + +// Create a new arena allocator +ArenaAllocator arena_init(); +// Allocate size bytes in the arena +void *arena_alloc(ArenaAllocator *alloc, size_t size); +// Destroy the arena, freeing its memory +void arena_drop(ArenaAllocator arena); + +#endif diff --git a/ser/assert.h b/ser/assert.h new file mode 100644 index 0000000..42d8350 --- /dev/null +++ b/ser/assert.h @@ -0,0 +1,37 @@ +#ifndef ASSERT_H +#define ASSERT_H + +#include "log.h" + +// Basic assertion macro (always checks) +#ifdef LOG_DISABLE +#define assert(c, fmt, ...) \ + do { \ + if (!(c)) { \ + fprintf(stderr, fmt "\n" __VA_OPT__(, ) __VA_ARGS__); \ + exit(1); \ + } \ + } while (false) +#else // LOG_DISABLE +#define assert(c, ...) \ + do { \ + if (!(c)) { \ + log_error(__VA_ARGS__); \ + exit(1); \ + } \ + } while (false) +#endif // LOG_DISABLE + +// Only check if NDEBUG isn't defined +#ifdef NDEBUG +#define debug_assert(c, ...) (void)0 +#else +#define debug_assert(c, ...) assert(c, __VA_ARGS__) +#endif + +#define assert_eq(a, b, ...) assert(a == b, __VA_ARGS__) + +// Assert allocation succeeded (var != NULL) +#define assert_alloc(var) debug_assert(var != NULL, "Failed to allocate memory for " #var " (Out of memory ?)") + +#endif diff --git a/ser/ast.c b/ser/ast.c new file mode 100644 index 0000000..5b0e087 --- /dev/null +++ b/ser/ast.c @@ -0,0 +1,121 @@ +#include "ast.h" + +#include "arena_allocator.h" +#include "vector.h" + +AstContext ast_init() { + return (AstContext){ + .root = NULL, + .alloc = arena_init(), + }; +} + +static void ast_node_drop(AstNode *node) { + switch (node->tag) { + case ATStruct: + vec_drop(node->struct_.fields); + break; + case ATMessage: + vec_drop(node->message.fields); + break; + case ATMessages: + for (size_t i = 0; i < node->messages.children.len; i++) { + ast_node_drop((AstNode *)&node->messages.children.data[i]); + } + vec_drop(node->messages.children); + break; + case ATItems: + for (size_t i = 0; i < node->items.items.len; i++) { + ast_node_drop((AstNode *)&node->items.items.data[i]); + } + vec_drop(node->items.items); + break; + default: + break; + } +} + +void ast_drop(AstContext ctx) { + if (ctx.root != NULL) { + ast_node_drop(ctx.root); + } + arena_drop(ctx.alloc); +} + +static void print(AstNode *node, uint32_t indent) { + const uint32_t I = 4; + switch (node->tag) { + case ATNumber: + fprintf(stderr, "%*sAstNumber(%.*s)\n", indent, "", node->number.token.span.len, node->number.token.lexeme); + break; + case ATIdent: + fprintf(stderr, "%*sAstIdent(%.*s)\n", indent, "", node->ident.token.span.len, node->ident.token.lexeme); + break; + case ATVersion: + fprintf(stderr, "%*sAstVersion:\n", indent, ""); + print((AstNode *)&node->version.version, indent + I); + break; + case ATNoSize: + fprintf(stderr, "%*sAstSize(none)\n", indent, ""); + break; + case ATMaxSize: + fprintf(stderr, "%*sAstSize(max):\n", indent, ""); + print((AstNode *)&node->size.value, indent + I); + break; + case ATFixedSize: + fprintf(stderr, "%*sAstSize(fixed):\n", indent, ""); + print((AstNode *)&node->size.value, indent + I); + break; + case ATHeapArray: + fprintf(stderr, "%*sAstArray(heap):\n", indent, ""); + print((AstNode *)node->array.type, indent + I); + print((AstNode *)&node->array.size, indent + I); + break; + case ATFieldArray: + fprintf(stderr, "%*sAstArray(field):\n", indent, ""); + print((AstNode *)node->array.type, indent + I); + print((AstNode *)&node->array.size, indent + I); + break; + case ATField: + fprintf(stderr, "%*sAstField(%.*s):\n", indent, "", node->field.name.span.len, node->field.name.lexeme); + print((AstNode *)&node->field.type, indent + I); + break; + case ATStruct: + fprintf(stderr, "%*sAstStruct(%.*s):\n", indent, "", node->struct_.ident.span.len, node->struct_.ident.lexeme); + for (size_t i = 0; i < node->struct_.fields.len; i++) { + print((AstNode *)&node->struct_.fields.data[i], indent + I); + } + break; + case ATMessage: + fprintf(stderr, "%*sAstMessage(%.*s):\n", indent, "", node->message.ident.span.len, node->message.ident.lexeme); + for (size_t i = 0; i < node->message.fields.len; i++) { + print((AstNode *)&node->message.fields.data[i], indent + I); + } + break; + case ATAttribute: + fprintf(stderr, "%*sAstAttribute(%.*s)\n", indent, "", node->attribute.ident.span.len, node->attribute.ident.lexeme); + break; + case ATMessages: + fprintf(stderr, "%*sAstMessages(%.*s):\n", indent, "", node->messages.name.span.len, node->messages.name.lexeme); + for (size_t i = 0; i < node->messages.children.len; i++) { + print((AstNode *)&node->messages.children.data[i], indent + I); + } + break; + case ATTypeDecl: + fprintf(stderr, "%*sAstTypeDecl(%.*s):\n", indent, "", node->type_decl.name.span.len, node->type_decl.name.lexeme); + print((AstNode *)&node->type_decl.value, indent + I); + break; + case ATConstant: + fprintf(stderr, "%*sAstConstant(%.*s):\n", indent, "", node->constant.name.span.len, node->constant.name.lexeme); + print((AstNode *)&node->constant.value, indent + I); + break; + case ATItems: + fprintf(stderr, "%*sAstItems:\n", indent, ""); + for (size_t i = 0; i < node->items.items.len; i++) { + print((AstNode *)&node->items.items.data[i], indent + I); + } + break; + } +} + +void ast_print(AstNode *node) { print(node, 0); } diff --git a/ser/ast.h b/ser/ast.h new file mode 100644 index 0000000..0236e19 --- /dev/null +++ b/ser/ast.h @@ -0,0 +1,329 @@ +#ifndef AST_H +#define AST_H +#include "arena_allocator.h" +#include "lexer.h" +#include "source.h" +#include "vector_impl.h" + +typedef enum { + ATNumber, + ATVersion, + ATIdent, + ATHeapArray, + ATFieldArray, + ATMaxSize, + ATFixedSize, + ATNoSize, + ATField, + ATAttribute, + ATStruct, + ATMessage, + ATMessages, + ATTypeDecl, + ATConstant, + ATItems, +} AstTag; + +typedef struct { + AstTag tag; + Span span; + Token token; +} AstNumber; + +typedef struct { + AstTag tag; + Span span; + Token token; +} AstIdent; + +typedef struct { + AstTag tag; + Span span; + AstNumber version; +} AstVersion; + +typedef struct { + AstTag tag; + Span span; + AstNumber value; +} AstSize; + +typedef struct { + AstTag tag; + Span span; + struct AstType *type; + AstSize size; +} AstArray; + +typedef union { + AstTag tag; + AstIdent ident; + AstArray array; +} AstType; + +typedef struct { + AstTag tag; + Span span; + Token name; + AstType type; +} AstField; + +VECTOR_IMPL(AstField, AstFieldVec, ast_field); + +typedef struct { + AstTag tag; + Span span; + Token ident; + AstFieldVec fields; +} AstStruct; + +typedef struct { + AstTag tag; + Span span; + Token ident; + AstFieldVec fields; +} AstMessage; + +typedef struct { + AstTag tag; + Span span; + Token ident; +} AstAttribute; + +typedef union { + AstTag tag; + AstMessage message; + AstAttribute attribute; +} AstAttributeOrMessage; + +VECTOR_IMPL(AstAttributeOrMessage, AstAttributeOrMessageVec, ast_attribute_or_message); + +typedef struct { + AstTag tag; + Span span; + Token name; + AstAttributeOrMessageVec children; +} AstMessages; + +typedef struct { + AstTag tag; + Span span; + Token name; + AstType value; +} AstTypeDecl; + +typedef struct { + AstTag tag; + Span span; + Token name; + AstNumber value; +} AstConstant; + +typedef union { + AstTag tag; + AstTypeDecl type_decl; + AstVersion version; + AstStruct struct_; + AstMessages messages; + AstConstant constant; +} AstItem; + +VECTOR_IMPL(AstItem, AstItemVec, ast_item); + +typedef struct { + AstTag tag; + Span span; + AstItemVec items; +} AstItems; + +typedef union { + AstTag tag; + AstNumber number; + AstIdent ident; + AstVersion version; + AstSize size; + AstArray array; + AstType type; + AstField field; + AstStruct struct_; + AstMessage message; + AstAttribute attribute; + AstAttributeOrMessage attribute_or_message; + AstMessages messages; + AstTypeDecl type_decl; + AstConstant constant; + AstItems items; +} AstNode; + +typedef struct { + AstNode *root; + ArenaAllocator alloc; +} AstContext; + +AstContext ast_init(); + +void ast_drop(AstContext ctx); + +static inline AstNumber ast_number(AstContext ctx, Span span, Token lit) { + AstNumber res; + res.tag = ATNumber; + res.span = span; + res.token = lit; + return res; +} + +static inline AstIdent ast_ident(AstContext ctx, Span span, Token ident) { + AstIdent res; + res.tag = ATIdent; + res.span = span; + res.token = ident; + return res; +} + +static inline AstVersion ast_version(AstContext ctx, Span span, AstNumber number) { + AstVersion res; + res.tag = ATVersion; + res.span = span; + res.version = number; + return res; +} + +static inline AstArray ast_heap_array(AstContext ctx, Span span, AstType *type, AstSize size) { + AstArray res; + res.tag = ATHeapArray; + res.span = span; + res.type = (struct AstType *)type; + res.size = size; + return res; +} + +static inline AstArray ast_field_array(AstContext ctx, Span span, AstType *type, AstSize size) { + AstArray res; + res.tag = ATFieldArray; + res.span = span; + res.type = (struct AstType *)type; + res.size = size; + return res; +} + +static inline AstSize ast_max_size(AstContext ctx, Span span, AstNumber size) { + AstSize res; + res.tag = ATMaxSize; + res.span = span; + res.value = size; + return res; +} + +static inline AstSize ast_fixed_size(AstContext ctx, Span span, AstNumber size) { + AstSize res; + res.tag = ATFixedSize; + res.span = span; + res.value = size; + return res; +} + +static inline AstSize ast_no_size(AstContext ctx, Span span) { + AstSize res; + res.tag = ATNoSize; + res.span = span; + return res; +} + +static inline AstField ast_field(AstContext ctx, Span span, Token name, AstType type) { + AstField res; + res.tag = ATField; + res.span = span; + res.name = name; + res.type = type; + return res; +} + +static inline AstStruct ast_struct(AstContext ctx, Span span, Token name, AstFieldVec fields) { + AstStruct res; + res.tag = ATStruct; + res.span = span; + res.ident = name; + res.fields = fields; + return res; +} + +static inline AstMessage ast_message(AstContext ctx, Span span, Token name, AstFieldVec fields) { + AstMessage res; + res.tag = ATMessage; + res.span = span; + res.ident = name; + res.fields = fields; + return res; +} + +static inline AstAttribute ast_attribute(AstContext ctx, Span span, Token attribute) { + AstAttribute res; + res.tag = ATAttribute; + res.span = span; + res.ident = attribute; + return res; +} + +static inline AstMessages ast_messages(AstContext ctx, Span span, Token name, AstAttributeOrMessageVec children) { + AstMessages res; + res.tag = ATMessages; + res.span = span; + res.name = name; + res.children = children; + return res; +} + +static inline AstConstant ast_constant(AstContext ctx, Span span, Token name, AstNumber value) { + AstConstant res; + res.tag = ATConstant; + res.span = span; + res.name = name; + res.value = value; + return res; +} + +static inline AstTypeDecl ast_type_decl(AstContext ctx, Span span, Token name, AstType type) { + AstTypeDecl res; + res.tag = ATTypeDecl; + res.span = span; + res.name = name; + res.value = type; + return res; +} + +static inline AstItems ast_items(AstContext ctx, Span span, AstItemVec items) { + AstItems res; + res.tag = ATItems; + res.span = span; + res.items = items; + return res; +} + +void ast_print(AstNode *node); + +static inline const char *ast_tag_to_string(AstTag tag) { +#define _case(c) \ + case AT##c: \ + return #c + switch (tag) { + _case(Number); + _case(Version); + _case(Ident); + _case(HeapArray); + _case(FieldArray); + _case(MaxSize); + _case(FixedSize); + _case(NoSize); + _case(Field); + _case(Struct); + _case(Message); + _case(Attribute); + _case(Messages); + _case(TypeDecl); + _case(Constant); + _case(Items); + } +#undef _case +} + +#endif diff --git a/ser/codegen.c b/ser/codegen.c new file mode 100644 index 0000000..467e168 --- /dev/null +++ b/ser/codegen.c @@ -0,0 +1,193 @@ +#include "codegen.h" + +#include +#include + +static void buffered_writer_write(void *w, const char *data, size_t len) { + // We don't use vec_push array because we want the string to be null terminated at all time (while not really including the + // null character in the string / len) + BufferedWriter *bw = (BufferedWriter *)w; + vec_grow(&bw->buf, bw->buf.len + len + 1); + memcpy(&bw->buf.data[bw->buf.len], data, len); + bw->buf.data[bw->buf.len + len] = '\0'; + bw->buf.len += len; +} +static void buffered_writer_format(void *w, const char *fmt, va_list args) { + BufferedWriter *bw = (BufferedWriter *)w; + va_list args2; + va_copy(args2, args); + size_t cap = bw->buf.cap - bw->buf.len; + char *ptr = &bw->buf.data[bw->buf.len]; + int len = vsnprintf(ptr, cap, fmt, args); + if (cap <= len) { + // The writing failed + vec_grow(&bw->buf, bw->buf.len + len + 1); + ptr = &bw->buf.data[bw->buf.len]; + vsnprintf(ptr, len + 1, fmt, args2); + } + va_end(args2); + bw->buf.len += len; +} +static void file_writer_write(void *w, const char *data, size_t len) { + FileWriter *fw = (FileWriter *)w; + fwrite(data, 1, len, fw->fd); +} +static void file_writer_format(void *w, const char *fmt, va_list args) { + FileWriter *fw = (FileWriter *)w; + vfprintf(fw->fd, fmt, args); +} +static void null_writer_write(void *w, const char *data, size_t len) { } +static void null_writer_format(void *w, const char *fmt, va_list args) { } + +BufferedWriter buffered_writer_init() { + CharVec buf = vec_init(); + vec_grow(&buf, 512); + return (BufferedWriter){.w.write = buffered_writer_write, .w.format = buffered_writer_format, .buf = buf}; +} +void buffered_writer_drop(BufferedWriter w) { vec_drop(w.buf); } +FileWriter file_writer_init(const char *path) { + FILE *fd = fopen(path, "w"); + assert(fd != NULL, "couldn't open output file"); + return (FileWriter){.w.write = file_writer_write, .w.format = file_writer_format, .fd = fd}; +} +FileWriter file_writer_from_fd(FILE *fd) { + return (FileWriter){.w.write = file_writer_write, .w.format = file_writer_format, .fd = fd}; +} +void file_writer_drop(FileWriter w) { fclose(w.fd); } +NullWriter null_writer_init() { + return (NullWriter){.w.write = null_writer_write, .w.format = null_writer_format}; +} + +void wt_write(Writer *w, const char *data, size_t len) { w->write(w, data, len); } +void wt_format(Writer *w, const char *fmt, ...) { + va_list args; + va_start(args, fmt); + w->format(w, fmt, args); + va_end(args); +} + +typedef struct { + StructObject *obj; + PointerVec dependencies; +} StructDependencies; + +void strdeps_drop(void *item) { + StructDependencies *deps = (StructDependencies *)item; + vec_drop(deps->dependencies); +} + +impl_hashmap( + strdeps, StructDependencies, { return hash(state, (byte *)&v->obj, sizeof(StructObject *)); }, { return a->obj == b->obj; } +); + +int struct_object_compare(const void *a, const void *b) { + const StructObject *sa = *(StructObject **)a; + const StructObject *sb = *(StructObject **)b; + size_t len = sa->name.len < sb->name.len ? sa->name.len : sb->name.len; + return strncmp(sa->name.ptr, sb->name.ptr, len); +} + +void define_structs(Program *p, Writer *w, void (*define)(Writer *w, StructObject *obj, void *), void * user_data) { + Hashmap *dependencies = hashmap_init(strdeps_hash, strdeps_equal, strdeps_drop, sizeof(StructDependencies)); + + TypeDef *td = NULL; + while (hashmap_iter(p->typedefs, &td)) { + if (td->value->kind != TypeStruct) + continue; + StructObject *obj = (StructObject *)&td->value->type.struct_; + + StructDependencies deps = {.obj = obj, .dependencies = vec_init()}; + for (size_t i = 0; i < obj->fields.len; i++) { + TypeObject *type = obj->fields.data[i].type; + // Skip through the field arrays + while (type->kind == TypeArray && !type->type.array.heap) { + type = type->type.array.type; + } + + if (type->kind != TypeStruct) + continue; + + vec_push(&deps.dependencies, &type->type.struct_); + } + + hashmap_set(dependencies, &deps); + } + + PointerVec to_define = vec_init(); + size_t pass = 0; + do { + vec_clear(&to_define); + + StructDependencies *deps = NULL; + while (hashmap_iter(dependencies, &deps)) { + bool dependencies_met = true; + for (size_t i = 0; i < deps->dependencies.len; i++) { + if (hashmap_has(dependencies, &(StructDependencies){.obj = deps->dependencies.data[i]})) { + dependencies_met = false; + break; + } + } + + if (!dependencies_met) + continue; + vec_push(&to_define, deps->obj); + } + + qsort(to_define.data, to_define.len, sizeof(StructObject *), struct_object_compare); + + for (size_t i = 0; i < to_define.len; i++) { + StructObject *s = to_define.data[i]; + define(w, s, user_data); + hashmap_delete(dependencies, &(StructDependencies){.obj = s}); + } + pass++; + } while (to_define.len > 0); + + if (dependencies->count > 0) { + log_error("cyclic struct dependency without indirection couldn't be resolved"); + } + + hashmap_drop(dependencies); + vec_drop(to_define); +} + +char *pascal_to_snake_case(StringSlice str) { + CharVec res = vec_init(); + vec_grow(&res, str.len + 4); + bool was_upper = false; + for (size_t i = 0; i < str.len; i++) { + if (i == 0) { + vec_push(&res, tolower(str.ptr[i])); + continue; + } + + char c = str.ptr[i]; + if (isupper(c) && !was_upper) { + vec_push(&res, '_'); + } + + was_upper = isupper(c); + + vec_push(&res, tolower(c)); + } + + vec_push(&res, '\0'); + + return res.data; +} + +char *snake_case_to_screaming_snake_case(StringSlice str) { + CharVec res = vec_init(); + vec_grow(&res, str.len + 4); + for(size_t i = 0; i < str.len; i++) { + char c = str.ptr[i]; + if('a' <= c && c <= 'z') { + vec_push(&res, c - 'a' + 'A'); + } else { + vec_push(&res, c); + } + } + vec_push(&res, '\0'); + + return res.data; +} diff --git a/ser/codegen.h b/ser/codegen.h new file mode 100644 index 0000000..b635215 --- /dev/null +++ b/ser/codegen.h @@ -0,0 +1,73 @@ +#ifndef CODEGEN_H +#define CODEGEN_H +#include "eval.h" +#include "vector.h" + +#include +#include +#include +#include + +#define INDENT 4 + +#define MSG_MAGIC_START 0xCAFEF00DBEEFDEADUL +#define MSG_MAGIC_END 0xF00DBEEFCAFEDEADUL + +// Struct used to define the relative alignment when working with structs +typedef struct { + Alignment align; + uint8_t offset; +} CurrentAlignment; + +typedef struct { + void (*write)(void *w, const char *data, size_t len); + void (*format)(void *w, const char *fmt, va_list args); +} Writer; + +typedef struct { + Writer w; + CharVec buf; +} BufferedWriter; + +typedef struct { + Writer w; + FILE *fd; +} FileWriter; + +typedef struct { + Writer w; +} NullWriter; + +BufferedWriter buffered_writer_init(); +void buffered_writer_drop(BufferedWriter w); +FileWriter file_writer_init(const char *path); +FileWriter file_writer_from_fd(FILE *fd); +void file_writer_drop(FileWriter w); +NullWriter null_writer_init(); + +void wt_write(Writer *w, const char *data, size_t len); +void wt_format(Writer *w, const char *fmt, ...); + +// Define the structs of a program in the correct order (respecting direct dependencies) +void define_structs(Program *p, Writer *w, void (*define)(Writer *w, StructObject *, void *), void *user_data); +char *pascal_to_snake_case(StringSlice str); +char *snake_case_to_screaming_snake_case(StringSlice str); + +// Check if c is aligned to alignment to +static inline bool calign_is_aligned(CurrentAlignment c, Alignment to) { + assert(to.value <= c.align.value, "Can't know if calign is aligned to aligment if major alignment is less"); + return (c.offset & to.mask) == 0; +} +// Add offset to the offset of c +static inline CurrentAlignment calign_add(CurrentAlignment c, uint8_t offset) { + c.offset += offset; + c.offset &= c.align.mask; + return c; +} +// Compute the number of bytes of padding needed to be aligned to a from c. +static inline uint8_t calign_to(CurrentAlignment c, Alignment a) { + assert(a.value <= c.align.value, "Can't align when major alignment is less than requested alignment"); + return (-c.offset) & a.mask; +} + +#endif diff --git a/ser/codegen_c.c b/ser/codegen_c.c new file mode 100644 index 0000000..a4d4d0c --- /dev/null +++ b/ser/codegen_c.c @@ -0,0 +1,878 @@ +#include "codegen_c.h" + +#include "vector.h" +#include "vector_impl.h" + +#include + +typedef enum { + MTPointer, + MTArray, +} ModifierType; + +typedef struct { + ModifierType type; + uint64_t size; +} Modifier; + +#define MOD_PTR ((Modifier){.type = MTPointer}) +#define MOD_ARRAY(s) ((Modifier){.type = MTArray, .size = s}) + +VECTOR_IMPL(Modifier, ModifierVec, modifier); + +static inline const char *array_size_type(uint64_t size) { + if (size <= UINT8_MAX) { + return "uint8_t"; + } else if (size <= UINT16_MAX) { + return "uint16_t"; + } else if (size <= UINT32_MAX) { + return "uint32_t"; + } else { + return "uint64_t"; + } +} + +static void write_field(Writer *w, Field f, Modifier *mods, size_t len, uint32_t indent); +// Wrte the *base* type type with indentation +static void write_type(Writer *w, TypeObject *type, uint32_t indent) { + if (type->kind == TypePrimitif) { +#define _case(x, s) \ + case Primitif_##x: \ + wt_format(w, "%*s" #s " ", indent, ""); \ + break + switch (type->type.primitif) { + _case(u8, uint8_t); + _case(u16, uint16_t); + _case(u32, uint32_t); + _case(u64, uint64_t); + _case(i8, uint8_t); + _case(i16, uint16_t); + _case(i32, uint32_t); + _case(i64, uint64_t); + _case(f32, float); + _case(f64, double); + _case(char, char); + _case(bool, bool); + } +#undef _case + } else if (type->kind == TypeStruct) { + wt_format(w, "%*sstruct %.*s ", indent, "", type->type.struct_.name.len, type->type.struct_.name.ptr); + } else { + if (type->type.array.sizing == SizingMax) { + const char *len_type = array_size_type(type->type.array.size); + wt_format(w, "%*sstruct {\n%*s%s len;\n", indent, "", indent + INDENT, "", len_type); + Field f = {.name = STRING_SLICE("data"), .type = type->type.array.type}; + Modifier mod; + if (type->type.array.heap) { + mod = MOD_PTR; + } else { + mod = MOD_ARRAY(type->type.array.size); + } + write_field(w, f, &mod, 1, indent + INDENT); + wt_format(w, ";\n%*s} ", indent, ""); + } else { + log_error("Called write_type on a non base type"); + } + } +} + +// Algorithm to handle c types here: +// 0. Given a field with a name and a type. +// let a base type be a type without modifiers (any type that isn't an array with fixed sizing). +// let modifiers be a sequence of array or pointer (array have size), the type of a modifier is +// either array or pointer. +// let * be a pointer modifier and [n] be an array modifier of size n +// 1. initialize a list of modifers, a base type variable, +// current = type (of the field) +// while(is_array(current) && is_fixed_size(current)) do +// if is_heap(current) then +// push(modifiers, *) +// else +// push(modifiers, [size(current)]) +// end +// current = element_type(current) +// end +// base_type = current +// 3. we now have a base type, a list of modifiers and a field name. +// 4. emit base_type " " +// 5. walk modifiers in reverse (let m and p be the current and previous modifiers) +// if type(m) == pointer then +// if type(m) != type(p) then +// emit "(" +// end +// emit "*" +// end +// 6. emit field_name +// 7. walk modifiers forward (let m and p be the current and previous modifiers) +// if type(m) == array then +// if type(m) != type(p) then +// emit "(" +// end +// emit "[" size(m) "]" +// end +// 8. Examples +// given the field: +// foo: char&[1][2][3]&[4][5]&[6]&[7], +// 3. +// base_type = char +// modifiers = {*, *, [5], *, [3], [2], *} +// field_name = foo +// result = "" +// 4. +// result = "char " +// 5. +// result = "char *(*(**" +// 6. +// result = "char *(*(**foo" +// 7. +// result = "char *(*(**foo)[5])[3][2]" +// +// in a trivial case the algoritm works as expected: +// bar: char, +// 3. +// base_type = char +// modifiers = {} +// field_name = bar +// result = "" +// 4. +// result = "char " +// 5. +// result = "char " +// 6. +// result = "char bar" +// 7. +// result = "char bar" + +static void write_field(Writer *w, Field f, Modifier *mods, size_t len, uint32_t indent) { + TypeObject *type = f.type; + ModifierVec modifiers = vec_init(); + TypeObject *current = type; + _vec_modifier_push_array(&modifiers, mods, len); + while (current->kind == TypeArray && current->type.array.sizing == SizingFixed) { + if (current->type.array.heap) { + _vec_modifier_push(&modifiers, MOD_PTR); + } else { + _vec_modifier_push(&modifiers, MOD_ARRAY(current->type.array.size)); + } + current = current->type.array.type; + } + TypeObject *base_type = current; + + write_type(w, base_type, indent); + for (int i = modifiers.len - 1; i >= 0; i--) { + Modifier m = modifiers.data[i]; + if (m.type != MTPointer) + continue; + if (i == modifiers.len - 1 || modifiers.data[i + 1].type == m.type) { + wt_format(w, "*"); + } else { + wt_format(w, "(*"); + } + } + wt_format(w, "%.*s", f.name.len, f.name.ptr); + for (size_t i = 0; i < modifiers.len; i++) { + Modifier m = modifiers.data[i]; + if (m.type != MTArray) + continue; + if (i == 0 || modifiers.data[i - 1].type == m.type) { + wt_format(w, "[%lu]", m.size); + } else { + wt_format(w, ")[%lu]", m.size); + } + } + _vec_modifier_drop(modifiers); +} + +static void write_struct(Writer *w, StructObject *obj, void *_user_data) { + wt_format(w, "typedef struct %.*s {\n", obj->name.len, obj->name.ptr); + + for (size_t i = 0; i < obj->fields.len; i++) { + Field f = obj->fields.data[i]; + write_field(w, f, NULL, 0, INDENT); + wt_format(w, ";\n", f.name.len, f.name.ptr); + } + + wt_format(w, "} %.*s;\n\n", obj->name.len, obj->name.ptr); +} + +static void write_align(Writer *w, const char *var, const Alignment align, size_t indent) { + wt_format(w, "%*s%s = (byte*)(((((uintptr_t)%s - 1) >> %u) + 1) << %u);\n", indent, "", var, var, align.po2, align.po2); +} + +static void write_accessor(Writer *w, TypeObject *base_type, FieldAccessor fa, bool ptr) { + if (fa.indices.len == 0) + return; + + if (ptr) { + wt_write(w, "->", 2); + } else { + wt_write(w, ".", 1); + } + + TypeObject *t = base_type; + for (size_t j = 0; j < fa.indices.len; j++) { + uint64_t index = fa.indices.data[j]; + + if (t->kind == TypeStruct) { + if (j != 0) + wt_write(w, ".", 1); + StructObject *st = (StructObject *)&t->type.struct_; + wt_write(w, st->fields.data[index].name.ptr, st->fields.data[index].name.len); + t = st->fields.data[index].type; + } else if (t->kind == TypeArray) { + if (t->type.array.sizing == SizingMax) { + if (j != 0) + wt_write(w, ".", 1); + if (index == 0) { + uint64_t size = t->type.array.size; + wt_write(w, "len", 3); + const TypeObject *type; + if (size <= UINT8_MAX) { + type = &PRIMITIF_u8; + } else if (size <= UINT16_MAX) { + type = &PRIMITIF_u16; + } else if (size <= UINT32_MAX) { + type = &PRIMITIF_u32; + } else { + type = &PRIMITIF_u64; + } + t = (TypeObject *)type; + } else { + wt_write(w, "data", 4); + t = t->type.array.type; + } + } else { + wt_format(w, "[%lu]", index); + t = t->type.array.type; + } + } + } +} + +static bool is_field_accessor_heap_array(FieldAccessor fa, TypeObject *base_type) { + if (fa.indices.len == 0) + return base_type->kind == TypeArray && base_type->type.array.heap; + + // In the case of a heap array the last index will choose between length and data, + // but since we only care about the array + fa.indices.len--; + + TypeObject *t = base_type; + for (size_t i = 0; i < fa.indices.len; i++) { + uint64_t index = fa.indices.data[i]; + + if (t->kind == TypeStruct) { + StructObject *st = (StructObject *)&t->type.struct_; + t = st->fields.data[index].type; + } else if (t->kind == TypeArray) { + if (t->type.array.sizing == SizingMax) { + if (index == 0) { + return false; + } else { + t = t->type.array.type; + } + } else { + t = t->type.array.type; + } + } + } + + return t->kind == TypeArray && t->type.array.heap; +} + +static void write_type_serialization( + Writer *w, const char *base, bool ptr, Layout *layout, CurrentAlignment al, Hashmap *layouts, size_t indent, size_t depth, bool always_inline +) { + if (layout->fields.len == 0) + return; + + Alignment align = al.align; + size_t offset = al.offset; + + offset += calign_to(al, layout->fields.data[0].type->align); + + if (layout->type->kind == TypeStruct && layout->type->type.struct_.has_funcs && !always_inline) { + char *name = pascal_to_snake_case(layout->type->type.struct_.name); + char *deref = ptr ? "*" : ""; + wt_format(w, "%*sbuf += %s_serialize(%s%s, &buf[%lu]);\n", indent, "", name, deref, base, offset); + free(name); + return; + } + + size_t i = 0; + for (; i < layout->fields.len && layout->fields.data[i].size != 0; i++) { + FieldAccessor fa = layout->fields.data[i]; + wt_format(w, "%*s*(", indent, ""); + write_type(w, fa.type, 0); + wt_format(w, "*)&buf[%lu] = %s", offset, base); + write_accessor(w, layout->type, fa, ptr); + wt_write(w, ";\n", 2); + + offset += fa.size; + al = calign_add(al, fa.size); + } + + if (i < layout->fields.len) { + offset += calign_to(al, layout->fields.data[i].type->align); + wt_format(w, "%*sbuf += %lu;\n", indent, "", offset); + + for (; i < layout->fields.len; i++) { + FieldAccessor farr = layout->fields.data[i]; + FieldAccessor flen = field_accessor_clone(&farr); + // Access the length instead of data + flen.indices.data[flen.indices.len - 1] = 0; + + wt_format(w, "%*sfor(size_t i = 0; i < %s", indent, "", base); + write_accessor(w, layout->type, flen, ptr); + field_accessor_drop(flen); + char *vname = msprintf("e%lu", depth); + wt_format(w, "; i++) {\n%*stypeof(%s", indent + INDENT, "", base); + write_accessor(w, layout->type, farr, ptr); + wt_format(w, "[i]) %s = %s", vname, base); + write_accessor(w, layout->type, farr, ptr); + wt_format(w, "[i];\n"); + + Layout *arr_layout = hashmap_get(layouts, &(Layout){.type = farr.type}); + assert(arr_layout != NULL, "Type has no layout (How ?)"); + write_type_serialization( + w, + vname, + false, + arr_layout, + (CurrentAlignment){.align = farr.type->align, .offset = 0}, + layouts, + indent + INDENT, + depth + 1, + false + ); + wt_format(w, "%*s}\n", indent, ""); + free(vname); + } + write_align(w, "buf", align, indent); + } else { + offset += calign_to(al, align); + wt_format(w, "%*sbuf += %lu;\n", indent, "", offset); + } +} + +static void write_type_deserialization( + Writer *w, const char *base, bool ptr, Layout *layout, CurrentAlignment al, Hashmap *layouts, size_t indent, size_t depth, bool always_inline +) { + if (layout->fields.len == 0) + return; + + Alignment align = al.align; + size_t offset = al.offset; + + offset += calign_to(al, layout->fields.data[0].type->align); + + if (layout->type->kind == TypeStruct && layout->type->type.struct_.has_funcs && !always_inline) { + char *name = pascal_to_snake_case(layout->type->type.struct_.name); + char *ref = ptr ? "" : "&"; + wt_format(w, "%*sbuf += %s_deserialize(%s%s, &buf[%lu]);\n", indent, "", name, ref, base, offset); + free(name); + return; + } + + char *deref = ""; + if (layout->type->kind == TypePrimitif) { + deref = "*"; + } + + size_t i = 0; + for (; i < layout->fields.len && layout->fields.data[i].size != 0; i++) { + FieldAccessor fa = layout->fields.data[i]; + wt_format(w, "%*s%s%s", indent, "", deref, base); + write_accessor(w, layout->type, fa, ptr); + wt_format(w, " = *("); + write_type(w, fa.type, 0); + wt_format(w, "*)&buf[%lu]", offset, base); + wt_write(w, ";\n", 2); + + offset += fa.size; + al = calign_add(al, fa.size); + } + + if (i < layout->fields.len) { + offset += calign_to(al, layout->fields.data[i].type->align); + wt_format(w, "%*sbuf += %lu;\n", indent, "", offset); + + for (; i < layout->fields.len; i++) { + FieldAccessor farr = layout->fields.data[i]; + FieldAccessor flen = field_accessor_clone(&farr); + // Access the length instead of data + flen.indices.data[flen.indices.len - 1] = 0; + + if (is_field_accessor_heap_array(farr, layout->type)) { + wt_format(w, "%*s%s", indent, "", base); + write_accessor(w, layout->type, farr, ptr); + wt_format(w, " = malloc(%s", base); + write_accessor(w, layout->type, flen, ptr); + wt_format(w, " * sizeof(typeof(*%s", base); + write_accessor(w, layout->type, farr, ptr); + wt_format(w, ")));\n"); + } + wt_format(w, "%*sfor(size_t i = 0; i < %s", indent, "", base); + write_accessor(w, layout->type, flen, ptr); + field_accessor_drop(flen); + char *vname = msprintf("e%lu", depth); + wt_format(w, "; i++) {\n%*stypeof(&%s", indent + INDENT, "", base); + write_accessor(w, layout->type, farr, ptr); + wt_format(w, "[i]) %s = &%s", vname, base); + write_accessor(w, layout->type, farr, ptr); + wt_format(w, "[i];\n"); + + Layout *arr_layout = hashmap_get(layouts, &(Layout){.type = farr.type}); + assert(arr_layout != NULL, "Type has no layout (How ?)"); + write_type_deserialization( + w, + vname, + true, + arr_layout, + (CurrentAlignment){.align = farr.type->align, .offset = 0}, + layouts, + indent + INDENT, + depth + 1, + false + ); + wt_format(w, "%*s}\n", indent, ""); + free(vname); + } + write_align(w, "buf", align, indent); + } else { + offset += calign_to(al, align); + wt_format(w, "%*sbuf += %lu;\n", indent, "", offset); + } +} + +static int write_type_free(Writer *w, const char *base, TypeObject *type, Hashmap *layouts, size_t indent, size_t depth, bool always_inline) { + if (type->kind == TypePrimitif) { + return 0; + } else if (type->kind == TypeArray) { + BufferedWriter b = buffered_writer_init(); + Writer *w2 = (Writer *)&b; + + int total = 0; + + wt_format(w2, "%*sfor(size_t i = 0; i < ", indent, ""); + if (type->type.array.sizing == SizingMax) { + wt_format(w2, "%s.len; i++) {\n", base); + wt_format(w2, "%*stypeof(%s.data[i]) e%lu = %s.data[i];\n", indent + INDENT, "", base, depth, base); + } else { + wt_format(w2, "%lu; i++) {\n", type->type.array.size); + wt_format(w2, "%*stypeof(%s[i]) e%lu = %s[i];\n", indent + INDENT, "", base, depth, base); + } + + char *new_base = msprintf("e%lu", depth); + total += write_type_free(w2, new_base, type->type.array.type, layouts, indent + INDENT, depth + 1, false); + free(new_base); + wt_format(w2, "%*s}\n", indent, ""); + + if (total > 0) { + wt_write(w, b.buf.data, b.buf.len); + } + buffered_writer_drop(b); + + if (type->type.array.heap) { + wt_format(w, "%*sfree(%s.data);\n", indent, "", base); + total++; + } + + return total; + } else if (type->kind == TypeStruct) { + StructObject *s = (StructObject *)&type->type.struct_; + int total = 0; + + if(type->type.struct_.has_funcs && !always_inline) { + char *name = pascal_to_snake_case(s->name); + wt_format(w, "%*s%s_free(%s);\n", indent, "", name, base); + free(name); + + Layout *layout = hashmap_get(layouts, &(Layout){.type = type}); + assert(layout != NULL, "No layout for type that has funcs defined"); + for(size_t i = 0; i < layout->fields.len; i++) { + if(layout->fields.data[i].size == 0) total++; + } + + return total; + } + + for (size_t i = 0; i < s->fields.len; i++) { + Field f = s->fields.data[i]; + char *new_base = msprintf("%s.%.*s", base, f.name.len, f.name.ptr); + total += write_type_free(w, new_base, f.type, layouts, indent, depth, false); + free(new_base); + } + + return total; + } + + return 0; +} + +static void write_struct_func_decl(Writer *w, StructObject *obj, void *_user_data) { + obj->has_funcs = true; + + StringSlice sname = obj->name; + char *snake_case_name = pascal_to_snake_case(sname); + wt_format(w, "__attribute__((unused)) static int %s_serialize(struct %.*s val, byte *buf);\n", snake_case_name, sname.len, sname.ptr); + wt_format(w, "__attribute__((unused)) static int %s_deserialize(struct %.*s *val, const byte *buf);\n", snake_case_name, sname.len, sname.ptr); + wt_format(w, "__attribute__((unused)) static void %s_free(struct %.*s val);\n", snake_case_name, sname.len, sname.ptr); + free(snake_case_name); +} + +static void write_struct_func(Writer *w, StructObject *obj, void *user_data) { + Hashmap *layouts = user_data; + // Retreive original TypeObject pointer from struct object pointer. + TypeObject *t = (void *)((byte *)obj - offsetof(struct TypeObject, type)); + + Layout *layout = hashmap_get(layouts, &(Layout){.type = t}); + assert(layout != NULL, "No layout found for struct"); + + StringSlice sname = obj->name; + char *snake_case_name = pascal_to_snake_case(sname); + wt_format(w, "static int %s_serialize(struct %.*s val, byte *buf) {\n", snake_case_name, sname.len, sname.ptr); + wt_format(w, "%*sbyte * base_buf = buf;\n", INDENT, ""); + write_type_serialization(w, "val", false, layout, (CurrentAlignment){.offset = 0, .align = t->align}, layouts, INDENT, 0, true); + wt_format(w, "%*sreturn (int)(buf - base_buf);\n", INDENT, ""); + wt_format(w, "}\n"); + + wt_format(w, "static int %s_deserialize(struct %.*s *val, const byte *buf) {\n", snake_case_name, sname.len, sname.ptr); + wt_format(w, "%*sconst byte * base_buf = buf;\n", INDENT, ""); + write_type_deserialization(w, "val", true, layout, (CurrentAlignment){.offset = 0, .align = t->align}, layouts, INDENT, 0, true); + wt_format(w, "%*sreturn (int)(buf - base_buf);\n", INDENT, ""); + wt_format(w, "}\n"); + + wt_format(w, "static void %s_free(struct %.*s val) {", snake_case_name, sname.len, sname.ptr); + BufferedWriter b = buffered_writer_init(); + int f = write_type_free((Writer*)&b, "val", t, layouts, INDENT, 0, true); + if(f > 0) { + wt_format(w, "\n%.*s}\n\n", b.buf.len, b.buf.data); + } else { + wt_format(w, " }\n\n"); + } + buffered_writer_drop(b); + + free(snake_case_name); +} + +void codegen_c(Writer *header, Writer *source, const char *name, Program *p) { + char *uc_name = snake_case_to_screaming_snake_case((StringSlice){.ptr = name, .len = strlen(name)}); + wt_format( + header, + "// Generated file\n" + "#ifndef %s_H\n" + "#define %s_H\n" + "#include \n" + "#include \n" + "#include \n" + "\n" + "typedef unsigned char byte;\n" + "typedef uint64_t MsgMagic;\n" + "\n" + "#define MSG_MAGIC_SIZE sizeof(MsgMagic)\n" + "static const MsgMagic MSG_MAGIC_START = 0x%016lX;\n" + "static const MsgMagic MSG_MAGIC_END = 0x%016lX;\n" + "\n", + uc_name, + uc_name, + MSG_MAGIC_START, + MSG_MAGIC_END + ); + free(uc_name); + wt_format( + source, + "// Generated file\n" + "#include \"%s.h\"\n" + "#include \n" + "\n", + name + ); + + define_structs(p, header, write_struct, NULL); + define_structs(p, source, write_struct_func_decl, NULL); + wt_format(source, "\n"); + define_structs(p, source, write_struct_func, p->layouts); + + for (size_t i = 0; i < p->messages.len; i++) { + MessagesObject msgs = p->messages.data[i]; + + wt_format(header, "// %.*s\n\n", msgs.name.len, msgs.name.ptr); + wt_format( + header, + "typedef enum %.*sTag {\n%*s%.*sTagNone = 0,\n", + msgs.name.len, + msgs.name.ptr, + INDENT, + "", + msgs.name.len, + msgs.name.ptr + ); + for (size_t j = 0; j < msgs.messages.len; j++) { + wt_format( + header, + "%*s%.*sTag%.*s = %lu,\n", + INDENT, + "", + msgs.name.len, + msgs.name.ptr, + msgs.messages.data[j].name.len, + msgs.messages.data[j].name.ptr, + j + 1 + ); + } + wt_format(header, "} %.*sTag;\n\n", msgs.name.len, msgs.name.ptr); + + for (size_t j = 0; j < msgs.messages.len; j++) { + MessageObject msg = msgs.messages.data[j]; + + wt_format(header, "typedef struct %.*s%.*s {\n", msgs.name.len, msgs.name.ptr, msg.name.len, msg.name.ptr); + wt_format(header, "%*s%.*sTag tag;\n", INDENT, "", msgs.name.len, msgs.name.ptr); + + for (size_t k = 0; k < msg.fields.len; k++) { + Field f = msg.fields.data[k]; + write_field(header, f, NULL, 0, INDENT); + wt_format(header, ";\n"); + } + if (msg.attributes & Attr_versioned) { + write_field( + header, + (Field){.name.ptr = "_version", .name.len = 8, .type = (TypeObject *)&PRIMITIF_u64}, + NULL, + 0, + INDENT + ); + wt_format(header, ";\n"); + } + + wt_format(header, "} %.*s%.*s;\n\n", msgs.name.len, msgs.name.ptr, msg.name.len, msg.name.ptr); + } + + wt_format(header, "typedef union %.*sMessage {\n", msgs.name.len, msgs.name.ptr); + wt_format(header, "%*s%.*sTag tag;\n", INDENT, "", msgs.name.len, msgs.name.ptr); + for (size_t j = 0; j < msgs.messages.len; j++) { + MessageObject msg = msgs.messages.data[j]; + char *field = pascal_to_snake_case(msg.name); + wt_format(header, "%*s%.*s%.*s %s;\n", INDENT, "", msgs.name.len, msgs.name.ptr, msg.name.len, msg.name.ptr, field); + free(field); + } + wt_format(header, "} %.*sMessage;\n\n", msgs.name.len, msgs.name.ptr); + + char *name = pascal_to_snake_case(msgs.name); + wt_format( + header, + "// Serialize the message msg to buffer dst of size len, returns the length of the serialized message, or -1 on " + "error (buffer overflow)\n" + ); + wt_format(header, "int msg_%s_serialize(byte *dst, size_t len, %.*sMessage *msg);\n", name, msgs.name.len, msgs.name.ptr); + wt_format( + header, + "// Deserialize the message in the buffer src of size len into dst, return the length of the serialized message or " + "-1 on error.\n" + ); + wt_format( + header, + "int msg_%s_deserialize(const byte *src, size_t len, %.*sMessage *dst);\n\n", + name, + msgs.name.len, + msgs.name.ptr + ); + wt_format( + header, + "// Free the message (created by msg_%s_deserialize)\n" + "void msg_%s_free(%.*sMessage *msg);\n", + name, + name, + msgs.name.len, + msgs.name.ptr + ); + + char *tag_type = msprintf("%.*sTag", msgs.name.len, msgs.name.ptr); + PointerVec message_tos = vec_init(); + + for (size_t j = 0; j < msgs.messages.len; j++) { + MessageObject m = msgs.messages.data[j]; + TypeObject *to = malloc(sizeof(TypeObject)); + assert_alloc(to); + { + StructObject obj = {.name = m.name, .fields = vec_clone(&m.fields)}; + if (m.attributes & Attr_versioned) { + vec_push(&obj.fields, ((Field){.name.ptr = "_version", .name.len = 8, .type = (TypeObject *)&PRIMITIF_u64})); + } + to->type.struct_ = *(struct StructObject *)&obj; + to->kind = TypeStruct; + to->align = ALIGN_8; + } + Layout layout = type_layout(to); + vec_push(&message_tos, to); + + hashmap_set(p->layouts, &layout); + } + + { + wt_format( + source, + "int msg_%s_serialize(byte *buf, size_t len, %.*sMessage *msg) {\n", + name, + msgs.name.len, + msgs.name.ptr + ); + + wt_format(source, "%*sconst byte *base_buf = buf;\n", INDENT, ""); + wt_format(source, "%*sif(len < 2 * MSG_MAGIC_SIZE)\n", INDENT, ""); + wt_format(source, "%*sreturn -1;\n", INDENT * 2, ""); + wt_format(source, "%*s*(MsgMagic*)buf = MSG_MAGIC_START;\n", INDENT, ""); + wt_format(source, "%*sbuf += MSG_MAGIC_SIZE;\n", INDENT, ""); + wt_format(source, "%*sswitch(msg->tag) {\n", INDENT, ""); + wt_format(source, "%*scase %sNone:\n%*sbreak;\n", INDENT, "", tag_type, INDENT * 2, ""); + + for (size_t j = 0; j < msgs.messages.len; j++) { + MessageObject m = msgs.messages.data[j]; + TypeObject *mtype = message_tos.data[j]; + Layout *layout = hashmap_get(p->layouts, &(Layout){.type = mtype}); + assert(layout != NULL, "What ?"); + char *snake_case_name = pascal_to_snake_case(m.name); + char *base = msprintf("msg->%s", snake_case_name); + + wt_format(source, "%*scase %s%.*s: {\n", INDENT, "", tag_type, m.name.len, m.name.ptr); + wt_format(source, "%*s*(uint16_t *)buf = %s%.*s;\n", INDENT * 2, "", tag_type, m.name.len, m.name.ptr); + if (m.attributes & Attr_versioned) { + wt_format(source, "%*smsg->%s._version = %luUL;\n", INDENT * 2, "", snake_case_name, msgs.version); + } + write_type_serialization( + source, + base, + false, + layout, + (CurrentAlignment){.align = ALIGN_8, .offset = 2}, + p->layouts, + INDENT * 2, + 0, + false + ); + wt_format(source, "%*sbreak;\n%*s}\n", INDENT * 2, "", INDENT, ""); + + free(base); + free(snake_case_name); + } + wt_format(source, "%*s}\n", INDENT, ""); + wt_format(source, "%*s*(MsgMagic*)buf = MSG_MAGIC_END;\n", INDENT, ""); + wt_format(source, "%*sbuf += MSG_MAGIC_SIZE;\n", INDENT, ""); + wt_format(source, "%*sif(buf > base_buf + len)\n", INDENT, ""); + wt_format(source, "%*sreturn -1;\n", INDENT * 2, ""); + wt_format(source, "%*sreturn (int)(buf - base_buf);\n", INDENT, ""); + wt_format(source, "}\n"); + } + + { + wt_format( + source, + "\nint msg_%s_deserialize(const byte *buf, size_t len, %.*sMessage *msg) {\n", + name, + msgs.name.len, + msgs.name.ptr + ); + + wt_format(source, "%*sconst byte *base_buf = buf;\n", INDENT, ""); + wt_format(source, "%*sif(len < 2 * MSG_MAGIC_SIZE)\n", INDENT, ""); + wt_format(source, "%*sreturn -1;\n", INDENT * 2, ""); + wt_format(source, "%*sif(*(MsgMagic*)buf != MSG_MAGIC_START)\n", INDENT, ""); + wt_format(source, "%*sreturn -1;\n", INDENT * 2, ""); + wt_format(source, "%*sbuf += MSG_MAGIC_SIZE;\n", INDENT, ""); + wt_format(source, "%*s%s tag = *(uint16_t*)buf;\n", INDENT, "", tag_type); + wt_format(source, "%*sswitch(tag) {\n", INDENT, ""); + wt_format(source, "%*scase %sNone:\n%*sbreak;\n", INDENT, "", tag_type, INDENT * 2, ""); + + for (size_t j = 0; j < msgs.messages.len; j++) { + MessageObject m = msgs.messages.data[j]; + TypeObject *mtype = message_tos.data[j]; + Layout *layout = hashmap_get(p->layouts, &(Layout){.type = mtype}); + assert(layout != NULL, "What ?"); + char *snake_case_name = pascal_to_snake_case(m.name); + char *base = msprintf("msg->%s", snake_case_name); + + wt_format(source, "%*scase %s%.*s: {\n", INDENT, "", tag_type, m.name.len, m.name.ptr); + wt_format(source, "%*smsg->tag = %s%.*s;\n", INDENT * 2, "", tag_type, m.name.len, m.name.ptr); + write_type_deserialization( + source, + base, + false, + layout, + (CurrentAlignment){.align = ALIGN_8, .offset = 2}, + p->layouts, + INDENT * 2, + 0, + false + ); + if (m.attributes & Attr_versioned) { + wt_format(source, "%*sif(msg->%s._version != %luUL) {\n", INDENT * 2, "", snake_case_name, msgs.version); + wt_format(source, "%*sprintf(\"Mismatched version: peers aren't the same version", INDENT * 3, ""); + wt_format(source, ", expected %lu got %%lu.\\n\", msg->%s._version);\n", msgs.version, snake_case_name); + wt_format(source, "%*smsg_%s_free(msg);\n", INDENT * 3, "", name); + wt_format(source, "%*sreturn -1;\n", INDENT * 3, ""); + wt_format(source, "%*s}\n", INDENT * 2, ""); + } + wt_format(source, "%*sbreak;\n%*s}\n", INDENT * 2, "", INDENT, ""); + + free(base); + free(snake_case_name); + } + wt_format(source, "%*s}\n", INDENT, ""); + wt_format(source, "%*sif(*(MsgMagic*)buf != MSG_MAGIC_END) {\n", INDENT, ""); + wt_format(source, "%*smsg_%s_free(msg);\n", INDENT * 2, "", name); + wt_format(source, "%*sreturn -1;\n", INDENT * 2, ""); + wt_format(source, "%*s}\n", INDENT, ""); + wt_format(source, "%*sbuf += MSG_MAGIC_SIZE;\n", INDENT, ""); + wt_format(source, "%*sif(buf > base_buf + len) {\n", INDENT, ""); + wt_format(source, "%*smsg_%s_free(msg);\n", INDENT * 2, "", name); + wt_format(source, "%*sreturn -1;\n", INDENT * 2, ""); + wt_format(source, "%*s}\n", INDENT, ""); + wt_format(source, "%*sreturn (int)(buf - base_buf);\n", INDENT, ""); + wt_format(source, "}\n"); + } + + { + wt_format(source, "\nvoid msg_%s_free(%.*sMessage *msg) {\n", name, msgs.name.len, msgs.name.ptr); + + wt_format(source, "%*sswitch(msg->tag) {\n", INDENT, ""); + wt_format(source, "%*scase %sNone:\n%*sbreak;\n", INDENT, "", tag_type, INDENT * 2, ""); + + for (size_t j = 0; j < msgs.messages.len; j++) { + MessageObject m = msgs.messages.data[j]; + TypeObject *mtype = message_tos.data[j]; + + char *snake_case_name = pascal_to_snake_case(m.name); + char *base = msprintf("msg->%s", snake_case_name); + + wt_format(source, "%*scase %s%.*s: {\n", INDENT, "", tag_type, m.name.len, m.name.ptr); + write_type_free(source, base, mtype, p->layouts, INDENT * 2, 0, false); + wt_format(source, "%*sbreak;\n%*s}\n", INDENT * 2, "", INDENT, ""); + + free(base); + free(snake_case_name); + } + wt_format(source, "%*s}\n", INDENT, ""); + wt_format(source, "}\n"); + } + + for (size_t j = 0; j < message_tos.len; j++) { + TypeObject *to = message_tos.data[j]; + StructObject *s = (StructObject *)&to->type.struct_; + + vec_drop(s->fields); + free(to); + } + + vec_drop(message_tos); + + free(tag_type); + free(name); + } + + wt_format(header, "#endif\n"); +} + +typedef struct { + uint16_t field_count; + char *a[4]; +} AA; diff --git a/ser/codegen_c.h b/ser/codegen_c.h new file mode 100644 index 0000000..25ec215 --- /dev/null +++ b/ser/codegen_c.h @@ -0,0 +1,8 @@ +#ifndef CODEGEN_C_H +#define CODEGEN_C_H + +#include "codegen.h" + +void codegen_c(Writer *header, Writer *source, const char *name, Program *p); + +#endif diff --git a/ser/codegen_python.c b/ser/codegen_python.c new file mode 100644 index 0000000..1223b49 --- /dev/null +++ b/ser/codegen_python.c @@ -0,0 +1,537 @@ +#include "codegen_python.h" + +#include + +static void write_type_name(Writer *w, TypeObject *type, Hashmap *defined) { + if (type->kind == TypePrimitif) { +#define _case(x, str) \ + case Primitif_##x: \ + wt_format(w, str); \ + return + switch (type->type.primitif) { + _case(u8, "int"); + _case(u16, "int"); + _case(u32, "int"); + _case(u64, "int"); + _case(i8, "int"); + _case(i16, "int"); + _case(i32, "int"); + _case(i64, "int"); + _case(f32, "float"); + _case(f64, "float"); + _case(bool, "bool"); + _case(char, "str"); + } +#undef _case + } else if (type->kind == TypeArray) { + // If we have an array of char + if (type->type.array.type->kind == TypePrimitif && type->type.array.type->type.primitif == Primitif_char) { + wt_format(w, "str"); + return; + } + wt_format(w, "List["); + write_type_name(w, type->type.array.type, defined); + wt_format(w, "]"); + } else if (type->kind == TypeStruct) { + StructObject *obj = (StructObject *)&type->type.struct_; + if (defined == NULL || hashmap_has(defined, &obj)) { + wt_write(w, obj->name.ptr, obj->name.len); + } else { + wt_format(w, "'%.*s'", obj->name.len, obj->name.ptr); + } + } +} + +typedef enum { + Read, + Write, +} Access; + +static void write_field_accessor(Writer *w, const char *base, FieldAccessor fa, TypeObject *type, Access access) { + if (fa.indices.len == 0) { + wt_format(w, "%s", base); + return; + } + + BufferedWriter b = buffered_writer_init(); + Writer *w2 = (Writer *)&b; + + bool len = false; + wt_format(w2, "%s", base); + for (size_t i = 0; i < fa.indices.len; i++) { + uint64_t index = fa.indices.data[i]; + + if (type->kind == TypeStruct) { + StructObject *obj = (StructObject *)&type->type.struct_; + Field f = obj->fields.data[index]; + wt_format(w2, ".%.*s", f.name.len, f.name.ptr); + + type = f.type; + } else if (type->kind == TypeArray) { + if (type->type.array.sizing == SizingMax) { + len = index == 0; + break; + } else { + wt_format(w2, "[%lu]", index); + type = type->type.array.type; + } + } + } + + if (len) { + if (access == Read) { + wt_format(w, "len(%.*s)", b.buf.len, b.buf.data); + } else { + CharVec buf = b.buf; + for (size_t i = 0; i < buf.len; i++) { + if (buf.data[i] == '.') { + buf.data[i] = '_'; + } + } + wt_write(w, buf.data, buf.len); + wt_format(w, "_len"); + } + } else { + wt_write(w, b.buf.data, b.buf.len); + } + + buffered_writer_drop(b); +} + +static bool field_accessor_is_array_length(FieldAccessor fa, TypeObject *type) { + if (fa.indices.len == 0 || fa.indices.data[fa.indices.len - 1] != 0 || fa.type->kind != TypePrimitif) + return false; + + for (size_t i = 0; i < fa.indices.len - 1; i++) { + uint64_t index = fa.indices.data[i]; + if (type->kind == TypeStruct) { + StructObject *s = (StructObject *)&type->type.struct_; + + type = s->fields.data[index].type; + } else if (type->kind == TypeArray) { + type = type->type.array.type; + } + } + + return type->kind == TypeArray && type->type.array.sizing == SizingMax; +} + +// Find the index of the FieldAccessor that points to the length of this array +static uint64_t get_array_length_field_index(Layout *layout, FieldAccessor fa) { + size_t index = SIZE_MAX; + for (size_t j = 0; j < layout->fields.len && layout->fields.data[j].size != 0; j++) { + FieldAccessor f = layout->fields.data[j]; + if (f.indices.len != fa.indices.len) + continue; + + // Check the indices for equality, all but the last one should be equal + bool equal = true; + for (size_t k = 0; k < f.indices.len - 1; k++) { + if (f.indices.data[k] != fa.indices.data[k]) { + equal = false; + break; + } + } + + if (equal && f.indices.data[f.indices.len - 1] == 0) { + index = j; + break; + } + } + + if (index == SIZE_MAX) { + log_error("No length accessor for variable size array accessor"); + exit(1); + } + + return index; +} + +static void write_type_uninit(Writer *u, TypeObject *type) { + if (type->kind == TypeStruct) { + StructObject *s = (StructObject *)&type->type.struct_; + for (size_t i = 0; i < s->fields.len; i++) { + Field f = s->fields.data[i]; + if (f.type->kind == TypeStruct) { + StringSlice tname = f.type->type.struct_.name; + wt_format(u, "%*sself.%.*s = %.*s.uninit()\n", INDENT * 2, "", f.name.len, f.name.ptr, tname.len, tname.ptr); + } else if (f.type->kind == TypeArray) { + wt_format(u, "%*sself.%.*s = []\n", INDENT * 2, "", f.name.len, f.name.ptr); + } + } + } +} + +static void write_type_funcs( + Writer *s, + Writer *d, + TypeObject *type, + const char *base, + CurrentAlignment al, + Hashmap *layouts, + size_t indent, + size_t depth, + bool always_inline +) { + Layout *layout = hashmap_get(layouts, &(Layout){.type = type}); + assert(layout != NULL, "Type has no layout"); + + if (layout->fields.len == 0) + return; + + Alignment align = al.align; + size_t offset = calign_to(al, layout->fields.data[0].type->align); + + if (offset > 0) { + wt_format(s, "%*sbuf += bytes(%lu)\n", indent, "", offset); + wt_format(d, "%*soff += %lu\n", indent, "", offset); + } + + if (type->kind == TypeStruct && !always_inline) { + offset = calign_to(al, layout->type->align); + if (offset != 0) { + wt_format(s, "%*sbuf += bytes(%lu)\n", indent, "", offset); + wt_format(d, "%*soff = %lu\n", indent, "", offset); + } + wt_format(s, "%*s%s.serialize(buf)\n", indent, "", base); + wt_format(d, "%*soff += %s.deserialize(buf[off:])\n", indent, "", base); + return; + } + + wt_format(s, "%*sbuf += pack('<", indent, ""); + wt_format(d, "%*sxs%lu = unpack('<", indent, "", depth); + al = calign_add(al, offset); + + size_t size = 0; + size_t i = 0; + for (; i < layout->fields.len && layout->fields.data[i].size != 0; i++) { + FieldAccessor fa = layout->fields.data[i]; + assert(fa.type->kind == TypePrimitif, "Field accessor of non zero size doesn't point to primitive type"); + +#define _case(x, f) \ + case Primitif_##x: \ + wt_format(s, f); \ + wt_format(d, f); \ + break + switch (fa.type->type.primitif) { + _case(u8, "B"); + _case(u16, "H"); + _case(u32, "I"); + _case(u64, "Q"); + _case(i8, "b"); + _case(i16, "h"); + _case(i32, "i"); + _case(i64, "q"); + _case(f32, "f"); + _case(f64, "d"); + _case(bool, "?"); + _case(char, "c"); + } +#undef _case + al = calign_add(al, fa.size); + offset += fa.size; + size += fa.size; + } + + size_t padding = 0; + if (i < layout->fields.len) { + padding = calign_to(al, layout->fields.data[i].type->align); + wt_format(s, "%lux", padding); + } + + wt_format(s, "',\n"); + wt_format(d, "', buf[off:off + %lu])\n", size); + + for (size_t j = 0; j < i; j++) { + FieldAccessor fa = layout->fields.data[j]; + + wt_format(s, "%*s ", indent, ""); + write_field_accessor(s, base, fa, type, Read); + if (fa.type->kind == TypePrimitif && fa.type->type.primitif == Primitif_char) { + wt_format(s, ".encode(encoding='ASCII', errors='replace')"); + } + if (j < i - 1) { + wt_format(s, ",\n"); + } else { + wt_format(s, ")\n"); + } + + if (!field_accessor_is_array_length(fa, type)) { + wt_format(d, "%*s", indent, ""); + write_field_accessor(d, base, fa, type, Write); + wt_format(d, " = xs%lu[%lu]", depth, j); + if (fa.type->kind == TypePrimitif && fa.type->type.primitif == Primitif_char) { + wt_format(d, ".decode(encoding='ASCII', errors='replace')"); + } + wt_format(d, "\n"); + } + } + + if (i < layout->fields.len) { + wt_format(d, "%*soff += %lu\n", indent, "", padding + size); + } else { + wt_format(d, "%*soff += %lu\n", indent, "", padding + size + calign_to(al, align)); + } + + bool alignment_unknown = false; + for (; i < layout->fields.len; i++) { + alignment_unknown = true; + + FieldAccessor fa = layout->fields.data[i]; + uint64_t len_index = get_array_length_field_index(layout, fa); + + if (fa.type->kind == TypePrimitif && fa.type->type.primitif == Primitif_char) { + wt_format(s, "%*sbuf += ", indent, ""); + wt_format(d, "%*s", indent, ""); + write_field_accessor(s, base, fa, type, Read); + write_field_accessor(d, base, fa, type, Write); + + wt_format(d, " = buf[off:off + xs%lu[%lu]].decode(encoding='ASCII', errors='replace')\n", depth, len_index); + wt_format(d, "%*soff += xs%lu[%lu]\n", indent, "", depth, len_index); + wt_format(s, ".encode(encoding='ASCII', errors='replace')\n"); + continue; + } + + wt_format(s, "%*sfor e%lu in ", indent, "", depth); + write_field_accessor(s, base, fa, type, Read); + wt_format(s, ":\n"); + wt_format(d, "%*s", indent, ""); + write_field_accessor(d, base, fa, type, Write); + wt_format(d, " = []\n"); + wt_format(d, "%*sfor _ in range(xs%lu[%lu]):\n", indent, "", depth, len_index); + if (fa.type->kind == TypeArray && fa.type->type.array.sizing == SizingFixed) { + wt_format(d, "%*se%lu = [None] * %lu\n", indent + INDENT, "", depth, fa.type->type.array.size); + } else if (fa.type->kind == TypeStruct) { + struct StructObject s = fa.type->type.struct_; + wt_format(d, "%*se%lu = %.*s.uninit()\n", indent + INDENT, "", depth, s.name.len, s.name.ptr); + } + char *new_base = msprintf("e%lu", depth); + write_type_funcs( + s, + d, + fa.type, + new_base, + (CurrentAlignment){.align = fa.type->align, .offset = 0}, + layouts, + indent + INDENT, + depth + 1, + false + ); + wt_format(d, "%*s", indent + INDENT, ""); + write_field_accessor(d, base, fa, type, Write); + wt_format(d, ".append(%s)\n", new_base); + free(new_base); + } + + if (alignment_unknown) { + wt_format(s, "%*sbuf += bytes((%u - len(buf)) & %u)\n", indent, "", align.value, align.mask); + wt_format(d, "%*soff += (%u - off) & %u\n", indent, "", align.value, align.mask); + } +} + +static void write_struct_class(Writer *w, StructObject *obj, Hashmap *defined, Hashmap *layouts) { + TypeObject *type = (void *)((byte *)obj - offsetof(TypeObject, type)); + + wt_format(w, "@dataclass\n"); + wt_format(w, "class %.*s:\n", obj->name.len, obj->name.ptr); + for (size_t i = 0; i < obj->fields.len; i++) { + Field f = obj->fields.data[i]; + wt_format(w, "%*s%.*s: ", INDENT, "", f.name.len, f.name.ptr); + write_type_name(w, f.type, defined); + wt_format(w, "\n"); + } + BufferedWriter ser = buffered_writer_init(); + BufferedWriter deser = buffered_writer_init(); + write_type_funcs( + (Writer *)&ser, + (Writer *)&deser, + type, + "self", + (CurrentAlignment){.align = type->align, .offset = 0}, + layouts, + INDENT * 2, + 0, + true + ); + + wt_format(w, "%*s\n", INDENT, ""); + wt_format(w, "%*s@classmethod\n", INDENT, ""); + wt_format(w, "%*sdef uninit(cls) -> '%.*s':\n", INDENT, "", obj->name.len, obj->name.ptr); + wt_format(w, "%*sself = cls.__new__(cls)\n", INDENT * 2, ""); + write_type_uninit(w, type); + wt_format(w, "%*sreturn self\n", INDENT * 2, ""); + wt_format(w, "%*s\n", INDENT, ""); + wt_format(w, "%*sdef serialize(self, buf: bytearray):\n", INDENT, ""); + wt_format(w, "%*sbase = len(buf)\n", INDENT * 2, ""); + wt_write(w, ser.buf.data, ser.buf.len); + wt_format(w, "%*sreturn len(buf) - base\n", INDENT * 2, ""); + wt_format(w, "%*s\n", INDENT, ""); + wt_format(w, "%*sdef deserialize(self, buf: bytes):\n", INDENT, ""); + wt_format(w, "%*soff = 0\n", INDENT * 2, ""); + wt_write(w, deser.buf.data, deser.buf.len); + wt_format(w, "%*sreturn off\n", INDENT * 2, ""); + wt_format(w, "\n"); + + buffered_writer_drop(ser); + buffered_writer_drop(deser); + + hashmap_set(defined, &obj); +} + +typedef struct { + Hashmap *layouts; + Hashmap *defined; +} CallbackData; + +static void write_struct(Writer *w, StructObject *obj, void *user_data) { + CallbackData *data = user_data; + write_struct_class(w, obj, data->defined, data->layouts); +} + +static void define_struct_classes(Writer *w, Program *p) { + Hashmap *defined = hashmap_init(pointer_hash, pointer_equal, NULL, sizeof(StructObject *)); + CallbackData data = {.defined = defined, .layouts = p->layouts}; + define_structs(p, w, write_struct, &data); + hashmap_drop(defined); +} + +static void define_message(Writer *w, const char *prefix, uint16_t tag, Hashmap *layouts, MessageObject msg, uint64_t version) { + char *name = msprintf("%s%.*s", prefix, msg.name.len, msg.name.ptr); + StringSlice name_slice = {.ptr = name, .len = strlen(name)}; + + wt_format(w, "@dataclass\n"); + wt_format(w, "class %s(%sMessage):\n", name, prefix); + + TypeObject *type; + StructObject *obj; + FieldVec fields = vec_clone(&msg.fields); + { + if (msg.attributes & Attr_versioned) { + Field f = {.name = STRING_SLICE("_version"), .type = (TypeObject *)&PRIMITIF_u64}; + vec_push(&fields, f); + } + + type = malloc(sizeof(TypeObject)); + assert_alloc(type); + type->kind = TypeStruct; + type->type.struct_.name = name_slice; + type->type.struct_.has_funcs = false; + type->type.struct_.fields = *(AnyVec *)&fields; + type->align = ALIGN_8; + obj = (StructObject *)&type->type.struct_; + + Layout l = type_layout(type); + hashmap_set(layouts, &l); + } + + for (size_t i = 0; i < msg.fields.len; i++) { + Field f = msg.fields.data[i]; + wt_format(w, "%*s%.*s: ", INDENT, "", f.name.len, f.name.ptr); + write_type_name(w, f.type, NULL); + wt_format(w, "\n"); + } + + if (msg.attributes & Attr_versioned) { + wt_format(w, "%*s_version: int = %lu\n", INDENT, "", version); + } + + BufferedWriter ser = buffered_writer_init(); + BufferedWriter deser = buffered_writer_init(); + write_type_funcs( + (Writer *)&ser, + (Writer *)&deser, + type, + "self", + (CurrentAlignment){.align = type->align, .offset = 2}, + layouts, + INDENT * 2, + 0, + true + ); + + wt_format(w, "%*s\n", INDENT, ""); + wt_format(w, "%*s@classmethod\n", INDENT, ""); + wt_format(w, "%*sdef uninit(cls) -> '%s':\n", INDENT, "", name); + wt_format(w, "%*sself = cls.__new__(cls)\n", INDENT * 2, ""); + write_type_uninit(w, type); + wt_format(w, "%*sreturn self\n", INDENT * 2, ""); + wt_format(w, "%*s\n", INDENT, ""); + wt_format(w, "%*sdef serialize(self, buf: bytearray):\n", INDENT, ""); + wt_format(w, "%*sbase = len(buf)\n", INDENT * 2, ""); + wt_format(w, "%*sbuf += pack('>QH', MSG_MAGIC_START, %u)\n", INDENT * 2, "", tag); + wt_write(w, ser.buf.data, ser.buf.len); + wt_format(w, "%*sbuf += pack('>Q', MSG_MAGIC_END)\n", INDENT * 2, ""); + wt_format(w, "%*sreturn len(buf) - base\n", INDENT * 2, ""); + wt_format(w, "%*s\n", INDENT, ""); + wt_format(w, "%*s@classmethod\n", INDENT, ""); + wt_format(w, "%*sdef _deserialize(cls, buf: bytes) -> Tuple['%s', int]:\n", INDENT, "", name); + wt_format(w, "%*smagic_start, tag = unpack('>QH', buf[0:10])\n", INDENT * 2, ""); + wt_format(w, "%*sif magic_start != MSG_MAGIC_START or tag != %u:\n", INDENT * 2, "", tag); + wt_format(w, "%*sraise ValueError\n", INDENT * 3, ""); + wt_format(w, "%*soff = 10\n", INDENT * 2, ""); + wt_format(w, "%*sself = %s.uninit()\n", INDENT * 2, "", name); + wt_write(w, deser.buf.data, deser.buf.len); + wt_format(w, "%*smagic_end = unpack('>Q', buf[off:off + 8])[0]\n", INDENT * 2, ""); + wt_format(w, "%*sif magic_end != MSG_MAGIC_END:\n", INDENT * 2, ""); + wt_format(w, "%*sraise ValueError\n", INDENT * 3, ""); + wt_format(w, "%*soff += 8\n", INDENT * 2, ""); + wt_format(w, "%*sreturn self, off\n", INDENT * 2, ""); + wt_format(w, "\n"); + + buffered_writer_drop(ser); + buffered_writer_drop(deser); + free(name); + vec_drop(fields); + free(type); +} + +static void define_messages(Writer *w, MessagesObject msgs, Program *p) { + char *prefix = strndup(msgs.name.ptr, msgs.name.len); + wt_format(w, "class %sMessage(ABC):\n", prefix); + wt_format(w, "%*s@abstractmethod\n", INDENT, ""); + wt_format(w, "%*sdef serialize(self, buf: bytearray) -> int:\n", INDENT, ""); + wt_format(w, "%*spass\n", INDENT * 2, ""); + wt_format(w, "%*s@classmethod\n", INDENT, ""); + wt_format(w, "%*sdef deserialize(cls, buf: bytes) -> Tuple['Message', int]:\n", INDENT, ""); + wt_format(w, "%*smagic_start, tag = unpack('>QH', buf[0:10])\n", INDENT * 2, ""); + wt_format(w, "%*sif magic_start != MSG_MAGIC_START:\n", INDENT * 2, ""); + wt_format(w, "%*sraise ValueError\n", INDENT * 3, ""); + for (size_t i = 0; i < msgs.messages.len; i++) { + if (i == 0) { + wt_format(w, "%*sif tag == 0:\n", INDENT * 2, ""); + } else { + wt_format(w, "%*selif tag == %lu:\n", INDENT * 2, "", i); + } + StringSlice name = msgs.messages.data[i].name; + wt_format(w, "%*sreturn %s%.*s._deserialize(buf)\n", INDENT * 3, "", prefix, name.len, name.ptr); + } + wt_format(w, "%*selse:\n", INDENT * 2, ""); + wt_format(w, "%*sraise ValueError\n", INDENT * 3, ""); + + for (size_t i = 0; i < msgs.messages.len; i++) { + define_message(w, prefix, i, p->layouts, msgs.messages.data[i], msgs.version); + } + free(prefix); +} + +void codegen_python(Writer *source, Program *p) { + wt_format( + source, + "# generated file\n" + "from dataclasses import dataclass\n" + "from typing import List, Tuple\n" + "from abc import ABC, abstractmethod\n" + "from struct import pack, unpack\n" + "\n" + "MSG_MAGIC_START = 0x%016lX\n" + "MSG_MAGIC_END = 0x%016lX\n" + "\n", + MSG_MAGIC_START, + MSG_MAGIC_END + ); + + define_struct_classes(source, p); + for (size_t i = 0; i < p->messages.len; i++) { + define_messages(source, p->messages.data[i], p); + } +} diff --git a/ser/codegen_python.h b/ser/codegen_python.h new file mode 100644 index 0000000..8067070 --- /dev/null +++ b/ser/codegen_python.h @@ -0,0 +1,8 @@ +#ifndef CODEGEN_PYTHON_H +#define CODEGEN_PYTHON_H + +#include "codegen.h" + +void codegen_python(Writer *source, Program *p); + +#endif diff --git a/ser/eval.c b/ser/eval.c new file mode 100644 index 0000000..11e8201 --- /dev/null +++ b/ser/eval.c @@ -0,0 +1,1318 @@ +#include "eval.h" + +#include "ast.h" +#include "gen_vec.h" +#include "hashmap.h" +#include "vector.h" +#include "vector_impl.h" + +#include +#include + +#define PRIMITIF_TO(name, al) \ + const TypeObject PRIMITIF_##name = {.kind = TypePrimitif, .align = _ALIGN_##al, .type.primitif = Primitif_##name} +PRIMITIF_TO(u8, 1); +PRIMITIF_TO(u16, 2); +PRIMITIF_TO(u32, 4); +PRIMITIF_TO(u64, 8); +PRIMITIF_TO(i8, 1); +PRIMITIF_TO(i16, 2); +PRIMITIF_TO(i32, 4); +PRIMITIF_TO(i64, 8); +PRIMITIF_TO(f32, 4); +PRIMITIF_TO(f64, 8); +PRIMITIF_TO(char, 1); +PRIMITIF_TO(bool, 1); +#undef PRIMITIF_TO + +void array_drop(Array a) { free(a.type); } + +void type_drop(TypeObject t) { + if (t.kind == TypeArray) { + array_drop(t.type.array); + } +} + +void struct_drop(StructObject s) { vec_drop(s.fields); } + +void message_drop(MessageObject m) { vec_drop(m.fields); } + +void messages_drop(MessagesObject m) { vec_drop(m.messages); } + +static Alignment max_alignment(Alignment a, Alignment b) { + if (a.value > b.value) { + return a; + } else { + return b; + } +} + +static char *attributes_to_string(Attributes attrs, bool and) { + uint32_t count = 0; + Attributes attributes[ATTRIBUTES_COUNT]; +#define handle(a) \ + if (attrs & Attr_##a) \ + attributes[count++] = Attr_##a; + handle(versioned); +#undef handle + CharVec res = vec_init(); + for (size_t i = 0; i < count; i++) { + if (i == 0) { + } else if (i < count - 1) { + vec_push_array(&res, ", ", 2); + } else if (and) { + vec_push_array(&res, " and ", 5); + } else { + vec_push_array(&res, " or ", 4); + } + +#define _case(x) \ + case Attr_##x: \ + vec_push_array(&res, #x, sizeof(#x) - 1); \ + break + switch (attributes[i]) { + _case(versioned); + default: + vec_push_array(&res, "(invalid attribute)", 19); + break; + } +#undef _case + } + vec_push(&res, '\0'); + return res.data; +} + +static inline EvalError err_duplicate_def(Span first, Span second, AstTag type, StringSlice ident) { + return (EvalError){ + .dup = {.tag = EETDuplicateDefinition, .first = first, .second = second, .type = type, .ident = ident} + }; +} + +static inline EvalError err_unknown(Span span, AstTag type, StringSlice ident) { + return (EvalError){ + .unk = {.tag = EETUnknown, .span = span, .type = type, .ident = ident} + }; +} + +static inline EvalError err_empty(Span span, AstTag type, StringSlice ident) { + return (EvalError){ + .empty = {.tag = EETEmptyType, .span = span, .type = type, .ident = ident} + }; +} + +void eval_error_report(Source *src, EvalError *err) { + switch (err->tag) { + case EETUnknown: { + EvalErrorUnknown unk = err->unk; + ReportSpan span = {.span = unk.span, .sev = ReportSeverityError}; + const char *type = ""; + switch (unk.type) { + case ATConstant: + type = "constant"; + break; + case ATAttribute: + type = "attribute"; + break; + default: + type = "identifier"; + break; + } + char *help = NULL; + if (unk.type == ATAttribute) { + char *attributes = attributes_to_string(~0, false); + help = msprintf("expected %s", attributes); + free(attributes); + } + source_report( + src, + unk.span.loc, + ReportSeverityError, + &span, + 1, + help, + "Unknown %s '%.*s'", + type, + unk.ident.len, + unk.ident.ptr + ); + if (help != NULL) { + free(help); + } + break; + } + case EETDuplicateDefinition: { + EvalErrorDuplicateDefinition dup = err->dup; + ReportSpan spans[] = { + {.span = dup.first, + .sev = ReportSeverityNote, + .message = msprintf("first definition of '%.*s' here", dup.ident.len, dup.ident.ptr)}, + {.span = dup.second, .sev = ReportSeverityError, .message = "redefined here"} + }; + const char *type = ""; + switch (dup.type) { + case ATConstant: + type = "constant"; + break; + case ATIdent: + type = "identifier"; + break; + case ATField: + type = "field"; + break; + default: + break; + } + source_report( + src, + dup.second.loc, + ReportSeverityError, + spans, + 2, + NULL, + "Duplicate definition of %s '%.*s'", + type, + dup.ident.len, + dup.ident.ptr + ); + free((char *)spans[0].message); + break; + } + case EETCycle: { + EvalErrorCycle cycle = err->cycle; + const char *type = ""; + if (cycle.type == ATConstant) { + type = "constant"; + } else if (cycle.type == ATTypeDecl) { + type = "type declaration"; + } + // Check if the spans are ordered + bool spans_ascending = true; + bool spans_descending = true; + for (size_t i = 1; i < cycle.spans.len; i++) { + int comp = span_compare(&cycle.spans.data[i - 1], &cycle.spans.data[i]); + spans_ascending = spans_ascending && comp >= 0; + spans_descending = spans_descending && comp <= 0; + if (!spans_ascending && !spans_descending) + break; + } + bool ordered = spans_ascending | spans_descending; + if (ordered) { + // If they are, we can print the info on a span each (less noisy output) + ReportSpanVec spans = vec_init(); + vec_grow(&spans, cycle.spans.len); + + for (size_t i = 0; i < cycle.spans.len; i++) { + ReportSeverity sev; + char *message; + StringSlice name = cycle.idents.data[i]; + StringSlice next_name = cycle.idents.data[(i + 1) % cycle.idents.len]; + + if (cycle.spans.len == 1) { // Special case for a constant equal to itself + sev = ReportSeverityError; + message = msprintf("%.*s requires evaluating itself", name.len, name.ptr); + } else if (i == 0) { // First equality + sev = ReportSeverityError; + message = msprintf("%.*s requires evaluating %.*s", name.len, name.ptr, next_name.len, next_name.ptr); + } else if (i < cycle.spans.len - 1) { + sev = ReportSeverityNote; + message = msprintf("... which requires %.*s ...", next_name.len, next_name.ptr); + } else { // Looparound + sev = ReportSeverityNote; + message = msprintf("... which again requires %.*s", next_name.len, next_name.ptr); + } + + vec_push(&spans, ((ReportSpan){.span = cycle.spans.data[i], .sev = sev, .message = message})); + } + + source_report( + src, + cycle.spans.data[0].loc, + ReportSeverityError, + spans.data, + spans.len, + NULL, + "cycle detected when evaluating %s '%.*s'", + type, + cycle.idents.data[0].len, + cycle.idents.data[0].ptr + ); + + for (size_t i = 0; i < spans.len; i++) { + free((char *)spans.data[i].message); + } + vec_drop(spans); + } else { + // If they aren't we have to use a report per span (because the lines are not ordered) + ReportSpan span; + + span.span = cycle.spans.data[0]; + span.sev = ReportSeverityError; + if (cycle.spans.len >= 2) { + span.message = NULL; + } else { + span.message = msprintf("%.*s requires evaluating itself", cycle.idents.data[0].len, cycle.idents.data[0].ptr); + } + + source_report( + src, + cycle.spans.data[0].loc, + ReportSeverityError, + &span, + 1, + NULL, + "cycle detected when evaluating %s '%.*s'", + type, + cycle.idents.data[0].len, + cycle.idents.data[0].ptr + ); + + if (span.message != NULL) { + free((char *)span.message); + } + + span.sev = ReportSeverityNote; + span.message = NULL; + for (size_t i = 1; i < cycle.idents.len; i++) { + span.span = cycle.spans.data[i]; + StringSlice name = cycle.idents.data[i]; + + if (i == cycle.idents.len - 1) { + span.message = + msprintf("which again requires evaluating %.*s", cycle.idents.data[0].len, cycle.idents.data[0].ptr); + } + source_report( + src, + span.span.loc, + ReportSeverityNote, + &span, + 1, + NULL, + "... which requires evaluating %.*s ...", + name.len, + name.ptr + ); + } + free((char *)span.message); + } + break; + } + case EETInfiniteStruct: { + EvalErrorInfiniteStruct infs = err->infs; + ReportSpanVec spans = vec_init(); + CharVec structs = vec_init(); + vec_grow(&spans, infs.fields.len + infs.structs.len); + + vec_push(&structs, '\''); + SpannedStringSlice last_struct = infs.structs.data[infs.structs.len - 1]; + vec_push_array(&structs, last_struct.slice.ptr, last_struct.span.len); + vec_push(&structs, '\''); + for (int i = infs.structs.len - 2; i >= 0; i--) { + if (i == 1) { + vec_push_array(&structs, " and ", 5); + } else { + vec_push_array(&structs, ", ", 2); + } + SpannedStringSlice s = infs.structs.data[i]; + vec_push(&structs, '\''); + vec_push_array(&structs, s.slice.ptr, s.slice.len); + vec_push(&structs, '\''); + } + vec_push(&structs, '\0'); + + for (int i = infs.structs.len - 1; i >= 0; i--) { + ReportSpan span[] = { + {.sev = ReportSeverityError, .message = NULL, .span = infs.structs.data[i].span}, + {.sev = ReportSeverityNote, .message = "recursive without limit", .span = infs.fields.data[i].span } + }; + vec_push_array(&spans, span, 2); + } + + source_report( + src, + infs.structs.data[0].span.loc, + ReportSeverityError, + spans.data, + spans.len, + "insert some limiting indirection ('[]', '&[]', or '&[^max size]') to break the cycle", + "recursive struct%s %s ha%s infinite size", + infs.structs.len > 1 ? "s" : "", + structs.data, + infs.structs.len > 1 ? "ve" : "s" + ); + vec_drop(spans); + vec_drop(structs); + break; + } + case EETEmptyType: { + EvalErrorEmptyType empty = err->empty; + char *type = ""; + if (empty.type == ATStruct) { + type = "struct"; + } else if (empty.type == ATMessage) { + type = "message"; + } + + ReportSpan span = {.span = empty.span, .sev = ReportSeverityError, .message = "zero sized types aren't allowed"}; + + source_report( + src, + empty.span.loc, + ReportSeverityError, + &span, + 1, + NULL, + "%s '%.*s' doesn't have any field", + type, + empty.ident.len, + empty.ident.ptr + ); + break; + } + } + fprintf(stderr, "\n"); +} + +void eval_error_drop(EvalError err) { + switch (err.tag) { + case EETCycle: + vec_drop(err.cycle.idents); + vec_drop(err.cycle.spans); + break; + case EETInfiniteStruct: + vec_drop(err.infs.structs); + vec_drop(err.infs.fields); + default: + break; + } +} + +static inline StringSlice string_slice_from_token(Token t) { return (StringSlice){.ptr = t.lexeme, .len = t.span.len}; } + +static SpannedStringSlice sss_from_token(Token t) { + return (SpannedStringSlice){.slice.ptr = t.lexeme, .slice.len = t.span.len, .span = t.span}; +} + +typedef struct { + Hashmap *constants; + Hashmap *typedefs; + Hashmap *layouts; + Hashmap *unresolved; + Hashmap *names; + PointerVec type_objects; + AstItemVec *items; + EvalErrorVec errors; + MessagesObjectVec messages; +} EvaluationContext; + +typedef struct { + StringSlice constant; + Token value; + Span name_span; + Span span; +} UnresolvedConstant; + +typedef struct { + StringSlice type; + AstNode value; + Span name_span; + Span span; +} UnresolvedTypeDef; + +typedef struct { + TypeObject *type; + StringSlice name; +} TypeName; + +impl_hashmap_delegate(unconst, UnresolvedConstant, string_slice, constant); +impl_hashmap_delegate(const, Constant, string_slice, name); +impl_hashmap_delegate(untypd, UnresolvedTypeDef, string_slice, type); +impl_hashmap_delegate(typedef, TypeDef, string_slice, name); +impl_hashmap( + layout, Layout, { return hash(state, (byte *)&v->type, sizeof(TypeObject *)); }, { return a->type == b->type; } +); +impl_hashmap( + typename, TypeName, { return hash(state, (byte *)&v->type, sizeof(TypeObject *)); }, { return a->type == b->type; } +); + +static uint64_t get_ast_number_value(EvaluationContext *ctx, AstNumber number) { + if (number.token.type == Number) { + return number.token.lit; + } else { // The token is an Ident + StringSlice ident = string_slice_from_token(number.token); + Constant *c = hashmap_get(ctx->constants, &(Constant){.name = ident}); + if (c != NULL) { + // If the constant is invalid we make up a value to continue checking for errors + // (Since it is invalid there already has been at least one and we know this code + // can't go to the next stage) + return c->valid ? c->value : 0; + } else { + // This constant doesn't exist: raise an error and return dummy value to continue + vec_push(&ctx->errors, err_unknown(number.token.span, ATConstant, ident)); + return 0; + } + } +} + +static Sizing ast_size_to_sizing(EvaluationContext *ctx, AstSize size, uint64_t *value) { + if (size.tag == ATMaxSize) { + *value = get_ast_number_value(ctx, size.value); + return SizingMax; + } else if (size.tag == ATFixedSize) { + *value = get_ast_number_value(ctx, size.value); + return SizingFixed; + } else { + *value = UINT16_MAX; + return SizingMax; + } +} + +static void _type_print(Hashmap *type_set, TypeObject *type) { + if (type == NULL) { + fprintf(stderr, "(invalid)"); + return; + } + + if (hashmap_set(type_set, &type)) { + if (type->kind == TypeStruct) { + fprintf(stderr, "%.*s", type->type.struct_.name.len, type->type.struct_.name.ptr); + } else { + fprintf(stderr, "(recursion)"); + } + return; + }; + + if (type->kind == TypePrimitif) { +#define _case(t) \ + case Primitif_##t: \ + fprintf(stderr, #t); \ + break + switch (type->type.primitif) { + _case(u8); + _case(u16); + _case(u32); + _case(u64); + _case(i8); + _case(i16); + _case(i32); + _case(i64); + _case(f32); + _case(f64); + _case(char); + _case(bool); + } +#undef _case + } else if (type->kind == TypeArray) { + _type_print(type_set, (TypeObject *)type->type.array.type); + if (type->type.array.heap) + fprintf(stderr, "&"); + if (type->type.array.sizing == SizingFixed) + fprintf(stderr, "[%lu]", type->type.array.size); + else if (type->type.array.sizing == SizingMax) + fprintf(stderr, "[^%lu]", type->type.array.size); + else + fprintf(stderr, "[]"); + } else { + StructObject s = *(StructObject *)&type->type.struct_; + fprintf(stderr, "{ "); + for (size_t i = 0; i < s.fields.len; i++) { + fprintf(stderr, "%.*s: ", s.fields.data[i].name.len, s.fields.data[i].name.ptr); + _type_print(type_set, s.fields.data[i].type); + if (i < s.fields.len - 1) { + fprintf(stderr, ", "); + } + } + fprintf(stderr, " }"); + } +} + +__attribute__((unused)) static void type_print(TypeObject *type) { + Hashmap *type_set = hashmap_init(pointer_hash, pointer_equal, NULL, sizeof(TypeObject *)); + _type_print(type_set, type); + hashmap_drop(type_set); +} + +static TypeObject *resolve_type(EvaluationContext *ctx, SpannedStringSlice name); + +static TypeObject *ast_type_to_type_obj(EvaluationContext *ctx, AstType type) { + if (type.tag == ATHeapArray || type.tag == ATFieldArray) { + TypeObject *res = malloc(sizeof(TypeObject)); + assert_alloc(res); + vec_push(&ctx->type_objects, res); + res->kind = TypeArray; + res->type.array.heap = type.tag == ATHeapArray; + res->type.array.sizing = ast_size_to_sizing(ctx, type.array.size, &res->type.array.size); + res->type.array.type = (struct TypeObject *)ast_type_to_type_obj(ctx, *(AstType *)type.array.type); + res->align.value = 0; + return res; + } else { // Otherwise the type is an identifier + return resolve_type(ctx, sss_from_token(type.ident.token)); + } +} + +static TypeObject *resolve_type(EvaluationContext *ctx, SpannedStringSlice name) { + TypeDef *type_def = hashmap_get(ctx->typedefs, &(TypeDef){.name = name.slice}); + if (type_def != NULL) { // Type is already resolved + return type_def->value; + } + + // Type isn't defined anywhere + if (ctx->unresolved == NULL || !hashmap_has(ctx->unresolved, &(UnresolvedTypeDef){.type = name.slice})) { + vec_push(&ctx->errors, err_unknown(name.span, ATIdent, name.slice)); + return NULL; + } + + UnresolvedTypeDef *untd = hashmap_get(ctx->unresolved, &(UnresolvedTypeDef){.type = name.slice}); + + if (untd->value.tag == ATIdent || untd->value.tag == ATFieldArray || untd->value.tag == ATHeapArray) { + hashmap_set(ctx->typedefs, &(TypeDef){.name = name.slice, .value = NULL}); + TypeObject *value = ast_type_to_type_obj(ctx, *(AstType *)&untd->value); + hashmap_set(ctx->typedefs, &(TypeDef){.name = name.slice, .value = value}); + return value; + } else { // Otherwise the value is a struct + AstStruct str = untd->value.struct_; + TypeObject *value = malloc(sizeof(TypeObject)); + { + FieldVec fields = vec_init(); + vec_grow(&fields, str.fields.len); + assert_alloc(value); + vec_push(&ctx->type_objects, value); + value->kind = TypeStruct; + value->type.struct_.fields = *(AnyVec *)&fields; + value->type.struct_.name = name.slice; + value->type.struct_.has_funcs = false; + value->align.value = 0; + hashmap_set(ctx->typedefs, &(TypeDef){.name = name.slice, .value = value}); + } + StructObject *stro = (StructObject *)&value->type.struct_; + + for (size_t i = 0; i < str.fields.len; i++) { + Field f; + f.name = string_slice_from_token(str.fields.data[i].name); + f.name_span = str.fields.data[i].name.span; + f.type = ast_type_to_type_obj(ctx, str.fields.data[i].type); + vec_push(&stro->fields, f); + } + + return value; + } +} + +// Check struct object for direct recursion, returns true if the struct contains a reference to rec somewhere +static bool check_for_recursion( + EvaluationContext *ctx, EvalErrorInfiniteStruct *err, Hashmap *checked, Hashmap *invalids, TypeObject *rec, StructObject *str +) { + // Shortcircuit if we already checked this struct + // This also avoids running into recursion + // (In the case of invalids there already has been an error, so we don't produce another) + if (hashmap_set(checked, &str) || hashmap_has(invalids, &str)) { + return false; + } + + for (size_t i = 0; i < str->fields.len; i++) { + Field f = str->fields.data[i]; + + TypeObject *type = f.type; + if (type == NULL) + continue; + + // Non heap arrays work very much the same as regular fields, Fixed size array as well (with added indirection) + while (type->kind == TypeArray && (!type->type.array.heap || type->type.array.sizing == SizingFixed)) { + type = type->type.array.type; + } + + // Anything else won't recurse: primitives can't, and heap arrays add indirection + // (Field arrays have been eliminated above) + if (type->kind != TypeStruct) { + continue; + } + + // If we got here the type is a struct + StructObject *obj = (StructObject *)&type->type.struct_; + + if (type == rec || check_for_recursion(ctx, err, checked, invalids, rec, obj)) { + // The struct contains rec + + UnresolvedTypeDef *unr = hashmap_get(ctx->unresolved, &(UnresolvedTypeDef){.type = str->name}); + SpannedStringSlice struct_ = {.slice = unr->type, .span = unr->name_span}; + AstField af = unr->value.struct_.fields.data[i]; + // af can be either ATFieldArray or ATIdent + while (af.type.tag == ATFieldArray || af.type.tag == ATHeapArray) { + af.type = *(AstType *)af.type.array.type; + } + SpannedStringSlice field = sss_from_token(af.type.ident.token); + + vec_push(&err->structs, struct_); + vec_push(&err->fields, field); + + hashmap_set(invalids, &str); + + return true; + } + } + + return false; +} + +static Alignment resolve_alignment(TypeObject *type, Hashmap *seen) { + // Check if the type has already been resolved + if (type->align.value != 0) { + return type->align; + } + + // Avoid cycles: if we already have seen this type (but not resolved), no need to check it again + // (since we're computing the max) + if (hashmap_set(seen, &type)) { + return ALIGN_1; + } + + if (type->kind == TypeStruct) { + Alignment res = ALIGN_1; + StructObject *s = (StructObject *)&type->type.struct_; + for (size_t i = 0; i < s->fields.len; i++) { + res = max_alignment(res, resolve_alignment(s->fields.data[i].type, seen)); + } + return res; + } + + // Type is type array (since primitive already have an alignment) + debug_assert(type->kind == TypeArray, ""); + + if (type->type.array.sizing == SizingMax) { + uint64_t size = type->type.array.size; + Alignment res; + if (size <= UINT8_MAX) { + res = ALIGN_1; + } else if (size <= UINT16_MAX) { + res = ALIGN_2; + } else if (size <= UINT32_MAX) { + res = ALIGN_4; + } else { + res = ALIGN_8; + } + res = max_alignment(res, resolve_alignment(type->type.array.type, seen)); + return res; + } + + // Type is fixed size array + return resolve_alignment(type->type.array.type, seen); +} + +void field_accessor_drop(FieldAccessor fa) { vec_drop(fa.indices); } +FieldAccessor field_accessor_clone(FieldAccessor *fa) { + return (FieldAccessor){.type = fa->type, .size = fa->size, .indices = vec_clone(&fa->indices)}; +} + +void layout_drop(void *l) { vec_drop(((Layout *)l)->fields); } + +static void add_fields(FieldAccessorVec *v, TypeObject *t, const uint64_t *base, size_t len) { + if (t->kind == TypePrimitif) { + FieldAccessor fa = {.indices = vec_init()}; +#define _case(typ, n) \ + case Primitif_##typ: \ + fa.size = n; \ + fa.type = (TypeObject *)&PRIMITIF_##typ; \ + break; + switch (t->type.primitif) { + _case(bool, 1); + _case(char, 1); + _case(i8, 1); + _case(u8, 1); + _case(i16, 2); + _case(u16, 2); + _case(i32, 4); + _case(u32, 4); + _case(f32, 4); + _case(i64, 8); + _case(u64, 8); + _case(f64, 8); + } +#undef _case + vec_push_array(&fa.indices, base, len); + vec_push(v, fa); + } else if (t->kind == TypeStruct) { + StructObject *s = (StructObject *)&t->type.struct_; + UInt64Vec new_base = vec_init(); + vec_grow(&new_base, len + 1); + vec_push_array(&new_base, base, len); + vec_push(&new_base, 0); + for (size_t i = 0; i < s->fields.len; i++) { + new_base.data[len] = i; + add_fields(v, s->fields.data[i].type, new_base.data, new_base.len); + } + vec_drop(new_base); + } else { // Type is array + if (t->type.array.sizing == SizingMax) { + FieldAccessor fa = {.indices = vec_init()}; + FieldAccessor fl = {.indices = vec_init()}; + vec_grow(&fa.indices, len + 1); + vec_grow(&fl.indices, len + 1); + vec_push_array(&fa.indices, base, len); + vec_push_array(&fl.indices, base, len); + vec_push(&fa.indices, 1); + vec_push(&fl.indices, 0); + + fa.size = 0; + fa.type = t->type.array.type; + + uint64_t size = t->type.array.size; + if (size <= UINT8_MAX) { + fl.size = 1; + fl.type = (TypeObject *)&PRIMITIF_u8; + } else if (size <= UINT16_MAX) { + fl.size = 2; + fl.type = (TypeObject *)&PRIMITIF_u16; + } else if (size <= UINT32_MAX) { + fl.size = 4; + fl.type = (TypeObject *)&PRIMITIF_u32; + } else { + fl.size = 8; + fl.type = (TypeObject *)&PRIMITIF_u64; + } + + vec_push(v, fa); + vec_push(v, fl); + } else { + UInt64Vec new_base = vec_init(); + vec_grow(&new_base, len + 1); + vec_push_array(&new_base, base, len); + vec_push(&new_base, 0); + for (size_t i = 0; i < t->type.array.size; i++) { + new_base.data[len] = i; + add_fields(v, t->type.array.type, new_base.data, new_base.len); + } + vec_drop(new_base); + } + } +} + +static int fa_compare(const void *a, const void *b) { + const FieldAccessor *fa = (const FieldAccessor *)a; + const FieldAccessor *fb = (const FieldAccessor *)b; + if (fb->size != 0 && fa->size == 0) + return 1; + if (fa->size != 0 && fb->size == 0) + return -1; + return (int)fb->type->align.value - (int)fa->type->align.value; +} + +Layout type_layout(TypeObject *type) { + Layout l = {.type = type, .fields = vec_init()}; + add_fields(&l.fields, type, NULL, 0); + qsort(l.fields.data, l.fields.len, sizeof(FieldAccessor), fa_compare); + return l; +} + +static void resolve_types(EvaluationContext *ctx) { + AstItemVec *items = ctx->items; + Hashmap *untypds = hashmap_init(untypd_hash, untypd_equal, NULL, sizeof(UnresolvedTypeDef)); + + ctx->unresolved = untypds; + + // Get the unresolved type definitions in the map and report duplicates + for (int i = 0; i < items->len; i++) { + if (items->data[i].tag != ATStruct && items->data[i].tag != ATTypeDecl) { + continue; + } + + UnresolvedTypeDef td; + if (items->data[i].tag == ATTypeDecl) { + AstTypeDecl t = items->data[i].type_decl; + td.type = string_slice_from_token(t.name); + td.span = t.span; + td.name_span = t.name.span; + td.value.type = t.value; + } else { + AstStruct s = items->data[i].struct_; + td.type = string_slice_from_token(s.ident); + td.span = s.span; + td.name_span = s.ident.span; + td.value.struct_ = s; + + if (s.fields.len == 0) { + vec_push(&ctx->errors, err_empty(s.ident.span, ATStruct, td.type)); + } + } + + UnresolvedTypeDef *original = hashmap_get(untypds, &td); + if (original != NULL) { + vec_push(&ctx->errors, err_duplicate_def(original->name_span, td.name_span, ATIdent, original->type)); + vec_take(items, i); + i--; + // Update value to last definition + hashmap_set(untypds, &td); + } else { + hashmap_set(untypds, &td); + } + } + + // Check for type declarations cycles / and resolve type declarations (give them a value) + for (int i = 0; i < items->len; i++) { + if (items->data[i].tag != ATTypeDecl) { + continue; + } + + hashmap_clear(ctx->names); + + AstTypeDecl td = items->data[i].type_decl; + StringSlice name = string_slice_from_token(td.name); + hashmap_set(ctx->names, &name); + bool valid = true; + AstType value = td.value; + + SpanVec spans = vec_init(); + StringSliceVec idents = vec_init(); + vec_push(&spans, td.span); + vec_push(&idents, name); + while (true) { + // Skip indirections + while (value.tag == ATFieldArray || value.tag == ATHeapArray) { + value = *(AstType *)value.array.type; + } + // Value is now an AstIdent. + SpannedStringSlice next = sss_from_token(value.ident.token); + + if (hashmap_set(ctx->names, &next.slice)) { + // We evaluate to a type we've already visited: cycle + + size_t index; + // Loop over idents (members of the cycle), set them as invalid and find the index + // of the first member of the cycle, the members before aren't actually part of it: + // A = B, B = C, C = B, A isn't part of the cycle (B <-> C) and shouldn't be reported + // (but is invalid) + for (size_t i = 0; i < idents.len; i++) { + if (string_slice_equal(&idents.data[i], &next.slice)) { + index = i; + } + + hashmap_set(ctx->typedefs, &(TypeDef){.name = idents.data[i], .value = NULL}); + } + + vec_splice(&spans, 0, index); + vec_splice(&idents, 0, index); + + EvalErrorCycle cycle; + cycle.tag = EETCycle; + cycle.type = ATTypeDecl; + cycle.spans = spans; + cycle.idents = idents; + // reinitialize the vectors to be dropped at the end. + // vec_init doesn't do any allocation so this is free + spans = (SpanVec)vec_init(); + idents = (StringSliceVec)vec_init(); + + EvalError err = {.cycle = cycle}; + + vec_push(&ctx->errors, err); + break; + } + + TypeDef *resolved = hashmap_get(ctx->typedefs, &(TypeDef){.name = next.slice}); + if (resolved != NULL) { + // The type declaration evaluates to a resolved type (a primitif type, or an invalid type) + if (resolved->value == NULL) { + // the type it evaluates to is invalid, so it is too + valid = false; + break; + } + // The type it evaluates to is valid: the type declaration doesn't contain any cycle + break; + } + + UnresolvedTypeDef *unr = hashmap_get(untypds, &(UnresolvedTypeDef){.type = next.slice}); + if (unr == NULL) { // The type evaluates to an unknown identifier + // Report error and set as invalid + vec_push(&ctx->errors, err_unknown(next.span, ATIdent, next.slice)); + valid = false; + break; + } + + if (unr->value.tag == ATStruct) { + // The type declarations evaluates to an (unresolved) struct: it can't cycle + break; + } else { + // The type declarations evaluates to another type declarations: we continue checking + vec_push(&spans, unr->span); + vec_push(&idents, next.slice); + value = unr->value.type; + } + } + + vec_drop(spans); + vec_drop(idents); + + if (!valid) { + // Set invalid + hashmap_set(ctx->typedefs, &(TypeDef){.name = name, .value = NULL}); + } + } + + hashmap_clear(ctx->names); + + // Resolves types (this accepts recursive types) + for (int i = 0; i < items->len; i++) { + if (items->data[i].tag != ATStruct && items->data[i].tag != ATTypeDecl) { + continue; + } + + SpannedStringSlice name; + if (items->data[i].tag == ATStruct) { + name = sss_from_token(items->data[i].struct_.ident); + } else { + name = sss_from_token(items->data[i].type_decl.name); + } + + resolve_type(ctx, name); + } + + Hashmap *checked = hashmap_init(pointer_hash, pointer_equal, NULL, sizeof(StructObject *)); + Hashmap *invalids = hashmap_init(pointer_hash, pointer_equal, NULL, sizeof(StructObject *)); + // Check for recursive types without indirections (infinite size) + for (int i = 0; i < items->len; i++) { + // TypeDecl can't be recursive + if (items->data[i].tag != ATStruct) { + continue; + } + + TypeDef *td = hashmap_get(ctx->typedefs, &(TypeDef){.name = string_slice_from_token(items->data[i].struct_.ident)}); + TypeObject *start = td->value; + StructObject *str = (StructObject *)&start->type.struct_; + + EvalErrorInfiniteStruct err = {.tag = EETInfiniteStruct, .fields = vec_init(), .structs = vec_init()}; + if (check_for_recursion(ctx, &err, checked, invalids, start, str)) { + EvalError e = {.infs = err}; + vec_push(&ctx->errors, e); + }; + hashmap_clear(checked); + } + + // Check structs for duplicate fields + Hashmap *names = hashmap_init(sss_hash, sss_equal, NULL, sizeof(SpannedStringSlice)); + for (int i = 0; i < items->len; i++) { + if (items->data[i].tag != ATStruct) { + continue; + } + + TypeDef *td = hashmap_get(ctx->typedefs, &(TypeDef){.name = string_slice_from_token(items->data[i].struct_.ident)}); + StructObject *str = (StructObject *)&td->value->type.struct_; + for (size_t i = 0; i < str->fields.len; i++) { + Field f = str->fields.data[i]; + SpannedStringSlice *prev = hashmap_get(names, &(SpannedStringSlice){.slice = f.name}); + if (prev != NULL) { + vec_push(&ctx->errors, err_duplicate_def(prev->span, f.name_span, ATField, f.name)); + continue; + } + hashmap_set(names, &(SpannedStringSlice){.slice = f.name, .span = f.name_span}); + } + hashmap_clear(names); + } + hashmap_drop(names); + + hashmap_drop(checked); + hashmap_drop(invalids); + hashmap_drop(untypds); + ctx->unresolved = NULL; +} + +static void resolve_constants(EvaluationContext *ctx) { + AstItemVec *items = ctx->items; + Hashmap *unconsts = hashmap_init(unconst_hash, unconst_equal, NULL, sizeof(UnresolvedConstant)); + Hashmap *names = ctx->names; + Hashmap *constants = ctx->constants; + + // Load unresolved constants into map (and check for duplicates) + for (int i = 0; i < items->len; i++) { + if (items->data[i].tag != ATConstant) { + continue; + } + + AstConstant c = items->data[i].constant; + UnresolvedConstant constant = + {.constant = string_slice_from_token(c.name), .name_span = c.name.span, .span = c.span, .value = c.value.token}; + UnresolvedConstant *original = hashmap_get(unconsts, &constant); + + if (original != NULL) { + vec_push(&ctx->errors, err_duplicate_def(original->name_span, constant.name_span, ATConstant, original->constant)); + vec_take(items, i); + i--; + // Update value to last + hashmap_set(unconsts, &constant); + } else { + hashmap_set(unconsts, &constant); + } + } + + for (size_t i = 0; i < items->len; i++) { + if (items->data[i].tag != ATConstant) { + continue; + } + + UnresolvedConstant *unc = + hashmap_get(unconsts, &(UnresolvedConstant){.constant = string_slice_from_token(items->data[i].constant.name)}); + hashmap_clear(names); + hashmap_set(names, &unc->constant); + Token value = unc->value; + while (value.type == Ident) { + StringSlice ident = string_slice_from_token(value); + Constant *resolved = hashmap_get(constants, &(Constant){.name = ident}); + // If the constant is set to another that is already resolved + if (resolved != NULL) { + if (!resolved->valid) { + // If the constant is invalid, break here, we know we won't be resolving this + break; + } + // We expect a token out of this loop, but we don't have one here, so we make one up that works + // only value.lit and value.type are read + value.type = Number; + value.lit = resolved->value; + break; + } + + if (hashmap_has(names, &ident)) { // Cycle detected on ident + EvalErrorCycle cycle; + cycle.tag = EETCycle; + cycle.type = ATConstant; + cycle.spans = (SpanVec)vec_init(); + cycle.idents = (StringSliceVec)vec_init(); + + // Walk the cycle again, keeping track of the spans, and marking every member + // as invalid + UnresolvedConstant *start = hashmap_get(unconsts, &(UnresolvedConstant){.constant = ident}); + UnresolvedConstant *cur = start; + do { + vec_push(&cycle.spans, cur->span); + vec_push(&cycle.idents, cur->constant); + hashmap_set(constants, &(Constant){.name = cur->constant, .value = 0, .valid = false}); + cur = hashmap_get(unconsts, &(UnresolvedConstant){.constant = string_slice_from_token(cur->value)}); + } while (cur != start); + + EvalError err = {.cycle = cycle}; + + vec_push(&ctx->errors, err); + break; + } + + // Get the constant the current is set to + UnresolvedConstant *c = hashmap_get(unconsts, &(UnresolvedConstant){.constant = ident}); + if (c == NULL) { // Constant doesn't exist + // throw error and mark invalid + vec_push(&ctx->errors, err_unknown(unc->value.span, ATConstant, ident)); + break; + } + + hashmap_set(names, &ident); + value = c->value; + } + + if (value.type == Ident) { // Constant couldn't be resolved + hashmap_set(constants, &(Constant){.name = unc->constant, .value = 0, .valid = false}); + } else { + hashmap_set(constants, &(Constant){.name = unc->constant, .value = value.lit, .valid = true}); + } + } + + hashmap_drop(unconsts); + hashmap_clear(names); +} + +static void resolve_messages(EvaluationContext *ctx) { + AstItemVec *items = ctx->items; + Hashmap *names = hashmap_init(sss_hash, sss_equal, NULL, sizeof(SpannedStringSlice)); + Hashmap *field_names = hashmap_init(sss_hash, sss_equal, NULL, sizeof(SpannedStringSlice)); + + ctx->messages = (MessagesObjectVec)vec_init(); + uint64_t version = ~0; + for (size_t i = 0; i < items->len; i++) { + if (items->data[i].tag == ATVersion) { + AstVersion v = items->data[i].version; + version = get_ast_number_value(ctx, v.version); + continue; + } + if (items->data[i].tag != ATMessages) { + continue; + } + AstMessages m = items->data[i].messages; + SpannedStringSlice name = sss_from_token(m.name); + Attributes attrs = AttrNone; + + MessagesObject res; + res.name = name.slice; + res.messages = (MessageObjectVec)vec_init(); + res.version = version; + + SpannedStringSlice *prev_name = hashmap_get(names, &name); + if (prev_name != NULL) { + vec_push(&ctx->errors, err_duplicate_def(prev_name->span, name.span, ATIdent, name.slice)); + } else { + hashmap_set(names, &name); + } + + for (size_t j = 0; j < m.children.len; j++) { + if (m.children.data[j].tag == ATAttribute) { + AstAttribute attr = m.children.data[j].attribute; + const char *a = attr.ident.lexeme; + uint32_t len = attr.ident.span.len; +#define _case(x) \ + if (strncmp(#x, a, sizeof(#x) - 1 > len ? sizeof(#x) - 1 : len) == 0) { \ + attrs |= Attr_##x; \ + continue; \ + } + _case(versioned); + + // If we get to here none of the above matched + vec_push(&ctx->errors, err_unknown(attr.ident.span, ATAttribute, string_slice_from_token(attr.ident))); +#undef _case + } else { + AstMessage msg = m.children.data[j].message; + + SpannedStringSlice name = sss_from_token(msg.ident); + + SpannedStringSlice *prev_name = hashmap_get(names, &name); + if (prev_name != NULL) { + vec_push(&ctx->errors, err_duplicate_def(prev_name->span, name.span, ATIdent, name.slice)); + } else { + hashmap_set(names, &name); + } + + MessageObject message; + message.name = name.slice; + message.attributes = attrs; + message.fields = (FieldVec)vec_init(); + vec_grow(&message.fields, msg.fields.len); + + for (size_t k = 0; k < msg.fields.len; k++) { + Field f; + f.name = string_slice_from_token(msg.fields.data[k].name); + f.name_span = msg.fields.data[k].name.span; + f.type = ast_type_to_type_obj(ctx, msg.fields.data[k].type); + vec_push(&message.fields, f); + + SpannedStringSlice *prev = hashmap_get(field_names, &(SpannedStringSlice){.slice = f.name}); + if (prev != NULL) { + vec_push(&ctx->errors, err_duplicate_def(prev->span, f.name_span, ATField, f.name)); + continue; + } + hashmap_set(field_names, &(SpannedStringSlice){.slice = f.name, .span = f.name_span}); + } + + hashmap_clear(field_names); + + vec_push(&res.messages, message); + + // Reset attributes after a message + attrs = AttrNone; + } + } + + vec_push(&ctx->messages, res); + version = ~0; + } + + hashmap_drop(names); + hashmap_drop(field_names); +} + +void resolve_additional_type_info(EvaluationContext *ctx) { + // Resolve alignment of all living type objects + Hashmap *seen = hashmap_init(pointer_hash, pointer_equal, NULL, sizeof(TypeObject *)); + for (size_t i = 0; i < ctx->type_objects.len; i++) { + ((TypeObject *)ctx->type_objects.data[i])->align = resolve_alignment(ctx->type_objects.data[i], seen); + hashmap_clear(seen); + } + + // Compute type layouts + Hashmap *layouts = hashmap_init(layout_hash, layout_equal, layout_drop, sizeof(Layout)); + for (size_t i = 0; i < ctx->type_objects.len; i++) { + Layout l = type_layout(ctx->type_objects.data[i]); + hashmap_set(layouts, &l); + } +#define _case(x) \ + { \ + Layout l = type_layout((TypeObject *)&PRIMITIF_##x); \ + hashmap_set(layouts, &l); \ + } + _case(u8); + _case(u16); + _case(u32); + _case(u64); + _case(i8); + _case(i16); + _case(i32); + _case(i64); + _case(char); + _case(bool); +#undef _case + + ctx->layouts = layouts; + + hashmap_drop(seen); +} + +void program_drop(Program p) { + for (size_t i = 0; i < p.type_objects.len; i++) { + TypeObject *ptr = p.type_objects.data[i]; + if (ptr->kind == TypeStruct) { + StructObject *str = (StructObject *)&ptr->type.struct_; + vec_drop(str->fields); + } + free(ptr); + } + vec_drop(p.type_objects); + + hashmap_drop(p.typedefs); + hashmap_drop(p.layouts); + vec_drop(p.messages); +} + +// Resolve statics of an AST (constants and type declarations); +EvaluationResult resolve_statics(AstContext *ctx) { + EvaluationContext ectx; + // resolved constants: value is a number, and the constant may be invalid + ectx.constants = hashmap_init(const_hash, const_equal, NULL, sizeof(Constant)); + ectx.typedefs = hashmap_init(typedef_hash, typedef_equal, NULL, sizeof(TypeDef)); + + // Set of names used to check for cycles + ectx.names = hashmap_init(string_slice_hash, string_slice_equal, NULL, sizeof(StringSlice)); + ectx.unresolved = NULL; + ectx.items = &ctx->root->items.items; + ectx.errors = (EvalErrorVec)vec_init(); + ectx.type_objects = (PointerVec)vec_init(); + + { +#define add_prim(type_name, type_size) \ + do { \ + hashmap_set( \ + ectx.typedefs, \ + &(TypeDef){.name.ptr = #type_name, .name.len = sizeof(#type_name) - 1, .value = (TypeObject *)&PRIMITIF_##type_name} \ + ); \ + } while (0) + add_prim(u8, 1); + add_prim(u16, 2); + add_prim(u32, 4); + add_prim(u64, 8); + add_prim(i8, 1); + add_prim(i16, 2); + add_prim(i32, 4); + add_prim(i64, 8); + add_prim(f32, 4); + add_prim(f64, 8); + add_prim(char, 1); + add_prim(bool, 1); +#undef add_prim + } + + resolve_constants(&ectx); + resolve_types(&ectx); + resolve_messages(&ectx); + resolve_additional_type_info(&ectx); + + hashmap_drop(ectx.names); + hashmap_drop(ectx.constants); + + Program p; + p.typedefs = ectx.typedefs; + p.layouts = ectx.layouts; + p.type_objects = ectx.type_objects; + p.messages = ectx.messages; + + return (EvaluationResult){.program = p, .errors = ectx.errors}; +} diff --git a/ser/eval.h b/ser/eval.h new file mode 100644 index 0000000..5a6271d --- /dev/null +++ b/ser/eval.h @@ -0,0 +1,255 @@ +#ifndef EVAL_H +#define EVAL_H +#include "ast.h" +#include "source.h" +#include "utils.h" +#include "vector_impl.h" + +#include +#include + +#define _ALIGN_1 \ + { .po2 = 0, .mask = 0, .value = 1 } +#define _ALIGN_2 \ + { .po2 = 1, .mask = 1, .value = 2 } +#define _ALIGN_4 \ + { .po2 = 2, .mask = 3, .value = 4 } +#define _ALIGN_8 \ + { .po2 = 3, .mask = 7, .value = 8 } + +#define ALIGN_1 ((Alignment)_ALIGN_1) +#define ALIGN_2 ((Alignment)_ALIGN_2) +#define ALIGN_4 ((Alignment)_ALIGN_4) +#define ALIGN_8 ((Alignment)_ALIGN_8) + +typedef struct { + uint8_t po2; + uint8_t mask; + uint8_t value; +} Alignment; + +static inline uint32_t align(uint32_t v, Alignment align) { return (((v - 1) >> align.po2) + 1) << align.po2; } + +typedef enum { + SizingMax, + SizingFixed, +} Sizing; + +typedef enum { + Primitif_u8, + Primitif_u16, + Primitif_u32, + Primitif_u64, + Primitif_i8, + Primitif_i16, + Primitif_i32, + Primitif_i64, + Primitif_f32, + Primitif_f64, + Primitif_char, + Primitif_bool, +} PrimitifType; + +typedef struct { + Sizing sizing; + uint64_t size; + bool heap; + struct TypeObject *type; +} Array; + +void array_drop(Array a); + +typedef enum { + TypeArray, + TypePrimitif, + TypeStruct, +} TypeKind; + +// Definition of StructObject used by TypeUnion +// Must match with later StructObject +struct StructObject { + StringSlice name; + AnyVec fields; + // Used by codegen_c + bool has_funcs; +}; + +typedef union { + Array array; + PrimitifType primitif; + struct StructObject struct_; +} TypeUnion; + +typedef struct TypeObject { + TypeKind kind; + Alignment align; + TypeUnion type; +} TypeObject; + +void type_drop(TypeObject t); + +typedef struct { + StringSlice name; + Span name_span; + TypeObject *type; +} Field; + +VECTOR_IMPL(Field, FieldVec, field); + +typedef struct { + StringSlice name; + FieldVec fields; + // Used by codegen_c + bool has_funcs; +} StructObject; + +void struct_drop(StructObject s); + +typedef struct { + StringSlice name; + TypeObject *value; +} TypeDef; + +void type_decl_drop(TypeDef t); + +typedef enum : uint32_t { + AttrNone = 0, + Attr_versioned = 1 << 0, +} Attributes; + +static const uint32_t ATTRIBUTES_COUNT = 1; + +typedef struct { + StringSlice name; + FieldVec fields; + Attributes attributes; +} MessageObject; + +void message_drop(MessageObject msg); + +VECTOR_IMPL(MessageObject, MessageObjectVec, message_object, message_drop); + +typedef struct { + StringSlice name; + MessageObjectVec messages; + uint64_t version; +} MessagesObject; + +void messages_drop(MessagesObject msg); + +VECTOR_IMPL(MessagesObject, MessagesObjectVec, messages_object, messages_drop); + +typedef struct { + StringSlice name; + bool valid; + uint64_t value; +} Constant; + +typedef struct { + UInt64Vec indices; + // Size of the field, or 0 if it isn't constant + uint64_t size; + TypeObject *type; +} FieldAccessor; + +void field_accessor_drop(FieldAccessor fa); +FieldAccessor field_accessor_clone(FieldAccessor *fa); + +VECTOR_IMPL(FieldAccessor, FieldAccessorVec, field_accessor, field_accessor_drop); + +typedef struct { + FieldAccessorVec fields; + TypeObject *type; +} Layout; + +Layout type_layout(TypeObject *to); + +void layout_drop(void *l); + +typedef struct { + Hashmap *typedefs; + Hashmap *layouts; + MessagesObjectVec messages; + PointerVec type_objects; +} Program; + +void program_drop(Program p); + +typedef enum { + EETDuplicateDefinition, + EETUnknown, + EETCycle, + EETInfiniteStruct, + EETEmptyType, +} EvalErrorTag; + +typedef struct { + EvalErrorTag tag; + Span first; + Span second; + StringSlice ident; + AstTag type; +} EvalErrorDuplicateDefinition; + +typedef struct { + EvalErrorTag tag; + Span span; + StringSlice ident; + AstTag type; +} EvalErrorUnknown; + +typedef struct { + EvalErrorTag tag; + SpanVec spans; + StringSliceVec idents; + AstTag type; +} EvalErrorCycle; + +typedef struct { + EvalErrorTag tag; + SpannedStringSliceVec structs; + SpannedStringSliceVec fields; +} EvalErrorInfiniteStruct; + +typedef struct { + EvalErrorTag tag; + Span span; + StringSlice ident; + AstTag type; +} EvalErrorEmptyType; + +typedef union { + EvalErrorTag tag; + EvalErrorDuplicateDefinition dup; + EvalErrorUnknown unk; + EvalErrorCycle cycle; + EvalErrorInfiniteStruct infs; + EvalErrorEmptyType empty; +} EvalError; + +void eval_error_drop(EvalError err); +void eval_error_report(Source *src, EvalError *err); + +VECTOR_IMPL(EvalError, EvalErrorVec, eval_error, eval_error_drop); + +typedef struct { + EvalErrorVec errors; + Program program; +} EvaluationResult; + +EvaluationResult resolve_statics(AstContext *ctx); + +extern const TypeObject PRIMITIF_u8; +extern const TypeObject PRIMITIF_u16; +extern const TypeObject PRIMITIF_u32; +extern const TypeObject PRIMITIF_u64; +extern const TypeObject PRIMITIF_i8; +extern const TypeObject PRIMITIF_i16; +extern const TypeObject PRIMITIF_i32; +extern const TypeObject PRIMITIF_i64; +extern const TypeObject PRIMITIF_f32; +extern const TypeObject PRIMITIF_f64; +extern const TypeObject PRIMITIF_char; +extern const TypeObject PRIMITIF_bool; + +#endif diff --git a/ser/gen_vec.c b/ser/gen_vec.c new file mode 100644 index 0000000..83612e9 --- /dev/null +++ b/ser/gen_vec.c @@ -0,0 +1,88 @@ +#include "gen_vec.h" + +#include "assert.h" + +#include + +#define NONE (~0) + +typedef struct { + uint64_t gen; + size_t next_free; +} Entry; + +GenVec genvec_init(size_t data_size, DropFunction drop) { + GenVec res; + res.len = 0; + res.count = 0; + res.size = data_size; + res.entry_size = data_size + sizeof(uint64_t); + res.drop = drop; + res.cap = 0; + res.data = NULL; + res.last_free = NONE; + res.gen = 1; + return res; +} + +static void genvec_grow(GenVec *v, size_t cap) { + if (v->cap >= cap) + return; + cap = v->cap * 2 > cap ? v->cap * 2 : cap; + if (v->cap != 0) { + v->data = realloc(v->data, cap * v->entry_size); + } else { + v->data = malloc(cap * v->entry_size); + } + assert_alloc(v->data); + v->cap = cap; +} + +GenIndex genvec_push(GenVec *v, void *item) { + if (v->last_free == NONE) { + genvec_grow(v, v->len + 1); + byte *ptr = v->data + v->len++ * v->entry_size; + ((Entry *)ptr)->gen = v->gen; + memcpy(ptr + sizeof(Entry), item, v->size); + v->count++; + return (GenIndex){.gen = v->gen, .index = v->len - 1}; + } else { + size_t index = v->last_free; + byte *ptr = v->data + index * v->entry_size; + Entry *entry = (Entry *)ptr; + v->last_free = entry->next_free; + entry->gen = v->gen; + memcpy(ptr + sizeof(Entry), item, v->size); + v->count++; + return (GenIndex){.gen = v->gen, .index = index}; + } +} + +void genvec_remove(GenVec *v, GenIndex idx) { + byte *ptr = v->data + idx.index * v->entry_size; + Entry *entry = (Entry *)ptr; + if (!entry->gen || entry->gen != idx.gen) + return; + entry->gen = 0; + entry->next_free = v->last_free; + v->last_free = idx.index; + if (v->drop != NULL) { + v->drop(ptr + sizeof(Entry)); + } + v->count--; + v->gen++; +} + +void *genvec_get(GenVec *v, GenIndex idx) { + byte *ptr = v->data + idx.index * v->entry_size; + Entry *entry = (Entry *)ptr; + if (!entry->gen || entry->gen != idx.gen) + return NULL; + return ptr + sizeof(Entry); +} + +void genvec_drop(GenVec v) { + if (v.cap >= 0) { + free(v.data); + } +} diff --git a/ser/gen_vec.h b/ser/gen_vec.h new file mode 100644 index 0000000..7188f7e --- /dev/null +++ b/ser/gen_vec.h @@ -0,0 +1,32 @@ +#ifndef GEN_VEC_H +#define GEN_VEC_H +#include +#include +#include +typedef unsigned char byte; +typedef void (*DropFunction)(void *item); + +typedef struct { + size_t size; + size_t entry_size; + size_t cap; + size_t len; + size_t count; + uint64_t gen; + byte *data; + size_t last_free; + DropFunction drop; +} GenVec; + +typedef struct { + uint64_t gen; + size_t index; +} GenIndex; + +GenVec genvec_init(size_t data_size, DropFunction drop); +GenIndex genvec_push(GenVec *v, void *item); +void genvec_remove(GenVec *v, GenIndex idx); +void *genvec_get(GenVec *v, GenIndex idx); +void genvec_drop(GenVec v); + +#endif diff --git a/ser/grammar.bnf b/ser/grammar.bnf new file mode 100644 index 0000000..37bd741 --- /dev/null +++ b/ser/grammar.bnf @@ -0,0 +1,19 @@ +items -> item* ; +item -> align | type_decl | struct | messages | constant; + +type_decl -> "type" IDENT "=" type ";" +align -> "align" "(" number ")" ";" +struct -> "struct" IDENT "{" field ("," field)* ","? "}" ; +messages -> "messages" IDENT "{" message* "}" ; +constant -> "const" IDENT "=" number ";" ; + +field -> IDENT ":" type ; +number -> NUMBER | IDENT ; +message -> IDENT "{" field ("," field)* ","? "}" ; + +type -> IDENT | heap_array | field_array ; +heap_array -> type "&" "[" "]" | type "[" "]" + | type "&" "[" max_size | fixed_size "]" ; +field_array -> type "[" max_size | fixed_size "]" ; +max_size -> "^" number ; +fixed_size -> number ; diff --git a/ser/hashmap.c b/ser/hashmap.c new file mode 100644 index 0000000..34ca951 --- /dev/null +++ b/ser/hashmap.c @@ -0,0 +1,346 @@ +#include "hashmap.h" + +#include "assert.h" +#include "utils.h" + +#include +#include +#include +#include +#include +#include + +#if __BYTE_ORDER__ == __LITTLE_ENDIAN +#define U32TO8_LE(p, v) (*(uint32_t *)(p) = v) +#define U8TO32_LE(p) (*(uint32_t *)(p)) +#else +#define U32TO8_LE(p, v) \ + do { \ + (p)[0] = (uint8_t)((v)); \ + (p)[1] = (uint8_t)((v) >> 8); \ + (p)[2] = (uint8_t)((v) >> 16); \ + (p)[3] = (uint8_t)((v) >> 24); \ + } while (0) + +#define U8TO32_LE(p) (((uint32_t)((p)[0])) | ((uint32_t)((p)[1]) << 8) | ((uint32_t)((p)[2]) << 16) | ((uint32_t)((p)[3]) << 24)) +#endif + +#define ROTL(x, b) (uint32_t)(((x) << (b)) | ((x) >> (32 - (b)))) +#define SIPROUND \ + do { \ + v0 += v1; \ + v1 = ROTL(v1, 5); \ + v1 ^= v0; \ + v0 = ROTL(v0, 16); \ + v2 += v3; \ + v3 = ROTL(v3, 8); \ + v3 ^= v2; \ + v0 += v3; \ + v3 = ROTL(v3, 7); \ + v3 ^= v0; \ + v2 += v1; \ + v1 = ROTL(v1, 13); \ + v1 ^= v2; \ + v2 = ROTL(v2, 16); \ + } while (0) + +// Kinda useless check +_Static_assert(sizeof(uint32_t) == 4, "uint32_t isn't 4 bytes"); + +uint32_t hash(Hasher state, const byte *data, const size_t len) { + uint32_t v0 = 0, v1 = 0, v2 = UINT32_C(0x6c796765), v3 = UINT32_C(0x74656462); + uint32_t k0 = U8TO32_LE((byte *)&state.key), k1 = U8TO32_LE(((byte *)&state.key) + 4); + uint32_t m; + // Pointer to the end of the last 4 byte block + const byte *end = data + len - (len % sizeof(uint32_t)); + const int left = len % sizeof(uint32_t); + uint32_t b = ((uint32_t)len) << 24; + v3 ^= k1; + v2 ^= k0; + v1 ^= k1; + v0 ^= k0; + + for (; data != end; data += 4) { + m = U8TO32_LE(data); + v3 ^= m; + for (int i = 0; i < 2; i++) { + SIPROUND; + } + v0 ^= m; + } + + switch (left) { + case 3: + b |= ((uint32_t)data[2]) << 16; + case 2: + b |= ((uint32_t)data[1]) << 8; + case 1: + b |= ((uint32_t)data[0]); + } + + v3 ^= b; + v2 ^= 0xff; + + for (int i = 0; i < 4; i++) { + SIPROUND; + } + + return v1 ^ v3; +} + +Hasher hasher_init() { + static Hasher HASHER = {.key = UINT64_C(0x5E3514A61CC01657)}; + static uint64_t COUNT = 0; + struct timespec ts; + timespec_get(&ts, TIME_UTC); + ts.tv_nsec += COUNT++; + ts.tv_sec ^= ts.tv_nsec; + uint64_t k; + ((uint32_t *)&k)[0] = hash(HASHER, (byte *)&ts.tv_sec, sizeof(ts.tv_sec)); + ((uint32_t *)&k)[1] = hash(HASHER, (byte *)&ts.tv_nsec, sizeof(ts.tv_nsec)); + // return (Hasher){.key = k}; + // TODO: change that back + return (Hasher){.key = 113223440}; +} + +// Must be a power of 2 +#define HASHMAP_BASE_CAP 64 +#define MAX_ITEMS(cap) (cap / (2)) + +typedef struct { + uint32_t hash; + bool occupied; +} __attribute__((aligned(8))) Bucket; + +Hashmap *hashmap_init(HashFunction hash, EqualFunction equal, DropFunction drop, size_t data_size) { + size_t aligned_size = (((data_size - 1) >> 3) + 1) << 3; + size_t entry_size = sizeof(Bucket) + aligned_size; + byte *alloc = malloc(sizeof(Hashmap)); + byte *buckets = malloc(HASHMAP_BASE_CAP * entry_size); + assert_alloc(alloc); + assert_alloc(buckets); + Hashmap *map = (Hashmap *)alloc; + map->size = data_size; + map->aligned_size = aligned_size; + map->entry_size = sizeof(Bucket) + aligned_size; + map->cap = HASHMAP_BASE_CAP; + map->mask = HASHMAP_BASE_CAP - 1; + map->count = 0; + map->max = MAX_ITEMS(HASHMAP_BASE_CAP); + map->state = hasher_init(); + map->hash = hash; + map->equal = equal; + map->drop = drop; + map->alloc = alloc; + map->buckets = buckets; + map->buckets_end = map->buckets + HASHMAP_BASE_CAP * map->entry_size; + + for (size_t i = 0; i < HASHMAP_BASE_CAP; i++) { + ((Bucket *)buckets)->occupied = false; + buckets += map->entry_size; + } + + return map; +} + +// Return the first empty bucket or the first matching bucket +static inline __attribute__((always_inline)) byte *hashmap_bucket(Hashmap *map, const void *item, uint32_t hash, size_t *rindex) { + int32_t index = hash & map->mask; + byte *ptr = map->buckets + index * map->entry_size; + while (((Bucket *)ptr)->occupied && (((Bucket *)ptr)->hash != hash || !map->equal(item, ptr + sizeof(Bucket)))) { + ptr += map->entry_size; + index++; + if (ptr >= map->buckets_end) { + ptr = map->buckets; + index = 0; + } + } + if (rindex != NULL) { + *rindex = index; + } + return ptr; +} + +static bool hashmap_insert(Hashmap *map, const void *item, uint32_t hash) { + byte *ptr = hashmap_bucket(map, item, hash, NULL); + Bucket *bucket = (Bucket *)ptr; + void *dst = ptr + sizeof(Bucket); + bool replace = bucket->occupied; + if (map->drop != NULL && replace) { + map->drop(dst); + } + + bucket->hash = hash; + bucket->occupied = true; + memcpy(dst, item, map->size); + if (!replace) { + map->count++; + } + return replace; +} + +// Grow hashmap to double the size +static void hashmap_grow(Hashmap *map) { + byte *old_buckets = map->buckets; + size_t old_cap = map->cap; + + map->cap *= 2; + map->mask = map->cap - 1; + map->count = 0; + map->max = MAX_ITEMS(map->cap); + map->buckets = malloc(map->cap * map->entry_size); + assert_alloc(map->buckets); + map->buckets_end = map->buckets + map->cap * map->entry_size; + + for (byte *ptr = map->buckets; ptr < map->buckets_end; ptr += map->entry_size) { + ((Bucket *)ptr)->occupied = false; + } + + byte *ptr = old_buckets; + for (size_t i = 0; i < old_cap; i++) { + Bucket *bucket = (Bucket *)ptr; + void *item = ptr + sizeof(Bucket); + if (bucket->occupied) { + hashmap_insert(map, item, bucket->hash); + } + ptr += map->entry_size; + } + + free(old_buckets); +} + +bool hashmap_set(Hashmap *map, const void *item) { + if (map->count >= map->max) { + hashmap_grow(map); + } + + uint32_t hash = map->hash(map->state, item); + return hashmap_insert(map, item, hash); +} + +void *hashmap_get(Hashmap *map, const void *key) { + uint32_t hash = map->hash(map->state, key); + byte *ptr = hashmap_bucket(map, key, hash, NULL); + Bucket *bucket = (Bucket *)ptr; + void *res = ptr + sizeof(Bucket); + if (!bucket->occupied) { + return NULL; + } else { + return res; + } +} + +bool hashmap_has(Hashmap *map, const void *key) { + uint32_t hash = map->hash(map->state, key); + byte *ptr = hashmap_bucket(map, key, hash, NULL); + Bucket *bucket = (Bucket *)ptr; + + return bucket->occupied; +} + +bool hashmap_take(Hashmap *map, const void *key, void *dst) { + uint32_t hash = map->hash(map->state, key); + byte *ptr = hashmap_bucket(map, key, hash, NULL); + Bucket *bucket = (Bucket *)ptr; + void *item = ptr + sizeof(Bucket); + + if (!bucket->occupied) { + return false; + } + + map->count--; + if (dst == NULL && map->drop != NULL) { + map->drop(item); + } else if (dst != NULL) { + memcpy(dst, item, map->size); + } + + byte *nptr = ptr; + while (true) { + // Kinda jank ? better solution ? + size_t index = (uintptr_t)(ptr - map->buckets) / map->entry_size; + + nptr += map->entry_size; + if (nptr >= map->buckets_end) { + nptr = map->buckets; + } + + while (((Bucket *)nptr)->occupied && (((Bucket *)nptr)->hash & map->mask) > index) { + nptr += map->entry_size; + if (nptr >= map->buckets_end) { + nptr = map->buckets; + } + } + + if (!((Bucket *)nptr)->occupied) { + bucket->occupied = false; + return true; + } + + *bucket = *(Bucket *)nptr; + memcpy(item, nptr + sizeof(Bucket), map->size); + + ptr = nptr; + bucket = (Bucket *)ptr; + item = ptr + sizeof(Bucket); + } +} + +void hashmap_clear(Hashmap *map) { + if (map->count == 0) + return; + + for (byte *ptr = map->buckets; ptr < map->buckets_end; ptr += map->entry_size) { + if (map->drop != NULL) { + map->drop(ptr + sizeof(Bucket)); + } + ((Bucket *)ptr)->occupied = false; + } + map->count = 0; +} + +bool hashmap_iter(Hashmap *map, void *iter_) { + void **iter = (void **)iter_; + if (*iter == NULL) { + if (map->count == 0) { + return false; + } + byte *ptr = map->buckets; + while (!((Bucket *)ptr)->occupied) { + ptr += map->entry_size; + } + *iter = ptr + sizeof(Bucket); + return true; + } + + byte *ptr = ((byte *)(*iter)) - sizeof(Bucket); + ptr += map->entry_size; + if (ptr >= map->buckets_end) + return false; + while (!((Bucket *)ptr)->occupied) { + ptr += map->entry_size; + if (ptr >= map->buckets_end) { + return false; + } + } + + *iter = ptr + sizeof(Bucket); + return true; +} + +void hashmap_drop(Hashmap *map) { + if (map->drop != NULL) { + byte *ptr = map->buckets; + for (size_t i = 0; i < map->cap; i++) { + Bucket *bucket = (Bucket *)ptr; + if (bucket->occupied) { + void *item = ptr + sizeof(Bucket); + map->drop(item); + } + ptr += map->entry_size; + } + } + + free(map->buckets); + free(map->alloc); +} diff --git a/ser/hashmap.h b/ser/hashmap.h new file mode 100644 index 0000000..4952b35 --- /dev/null +++ b/ser/hashmap.h @@ -0,0 +1,87 @@ +#ifndef HASHMAP_H +#define HASHMAP_H + +typedef unsigned char byte; + +#include "gen_vec.h" + +#include +#include +#include + +typedef struct { + uint64_t key; +} Hasher; + +// Create new hasher with a pseudo random state +Hasher hasher_init(); +// Hash given data with hasher +uint32_t hash(Hasher state, const byte *data, const size_t len); + +typedef uint32_t (*HashFunction)(Hasher state, const void *item); +typedef bool (*EqualFunction)(const void *a, const void *b); +typedef void (*DropFunction)(void *item); + +typedef struct { + size_t size; + size_t aligned_size; + size_t entry_size; + size_t cap; + size_t mask; + size_t count; + size_t max; + Hasher state; + byte *buckets; + byte *buckets_end; + byte *alloc; + HashFunction hash; + EqualFunction equal; + DropFunction drop; +} Hashmap; + +typedef struct { + Hashmap *map; + GenVec items; +} StableHashmap; + +// Initialize a new hashmap +Hashmap *hashmap_init(HashFunction hash, EqualFunction equal, DropFunction drop, size_t data_size); +// Insert value in hashmapn returns true if the value was overwritten +bool hashmap_set(Hashmap *map, const void *item); +// Get value of hashmap, return NULL if not found +void *hashmap_get(Hashmap *map, const void *key); +// Take a value from a hashmap and put it into dst +bool hashmap_take(Hashmap *map, const void *key, void *dst); +// Destroy hashmap +void hashmap_drop(Hashmap *map); +// Delete entry from hasmap +static inline __attribute__((always_inline)) bool hashmap_delete(Hashmap *map, const void *key) { + return hashmap_take(map, key, NULL); +} +// Check if hashmap contains key +bool hashmap_has(Hashmap *map, const void *key); +// Clear hashmap of all entries +void hashmap_clear(Hashmap *map); +// Iterate hasmap +bool hashmap_iter(Hashmap *map, void *iter); + +#define impl_hashmap(prefix, type, hash, equal) \ + uint32_t prefix##_hash(Hasher state, const void *_v) { \ + type *v = (type *)_v; \ + hash \ + } \ + bool prefix##_equal(const void *_a, const void *_b) { \ + type *a = (type *)_a; \ + type *b = (type *)_b; \ + equal \ + } \ + _Static_assert(1, "Semicolon required") +#define impl_hashmap_delegate(prefix, type, delegate, accessor) \ + impl_hashmap( \ + prefix, \ + type, \ + { return delegate##_hash(state, &v->accessor); }, \ + { return delegate##_equal(&a->accessor, &b->accessor); } \ + ) + +#endif diff --git a/ser/lexer.c b/ser/lexer.c new file mode 100644 index 0000000..ecc06ee --- /dev/null +++ b/ser/lexer.c @@ -0,0 +1,372 @@ +#include "lexer.h" + +#include "vector.h" + +#include +#include + +typedef struct { + uint32_t start; + uint32_t current; + Source *src; + Location loc; + Location start_loc; + TokenVec tokens; + LexingErrorVec errors; +} Lexer; + +static inline __attribute__((always_inline)) Token +token(Source *src, TokenType type, const char *lexeme, uint32_t len, uint64_t lit, Location loc) { + IF_DEBUG(src->ref_count++); + return (Token){ + .src = src, + .lit = lit, + .span.loc = loc, + .span.len = len, + .type = type, + .lexeme = lexeme, + }; +} +static inline __attribute__((always_inline)) LexingError +lexing_error(Source *src, LexingErrorType type, Location loc, uint32_t len) { + IF_DEBUG(src->ref_count++); + return (LexingError){ + .src = src, + .type = type, + .span.loc = loc, + .span.len = len, + }; +} + +void token_drop(Token t) { IF_DEBUG(t.src->ref_count--); } + +void lexing_error_drop(LexingError e) { IF_DEBUG(e.src->ref_count--); } + +void lexing_result_drop(LexingResult res) { + vec_drop(res.tokens); + vec_drop(res.errors); +} + +static Lexer lexer_init(Source *src) { + TokenVec tokens = vec_init(); + vec_grow(&tokens, 256); + return (Lexer){ + .start = 0, + .current = 0, + .src = src, + .loc = location(1, 0, 0), + .start_loc = location(1, 1, 0), + .tokens = tokens, + .errors = vec_init(), + }; +} + +static void lexer_add_token(Lexer *lex, TokenType type, uint32_t len, uint64_t lit) { + vec_push(&lex->tokens, token(lex->src, type, &lex->src->str[lex->start], len, lit, lex->start_loc)); +} + +static void lexer_add_error(Lexer *lex, LexingErrorType type, uint32_t len) { + vec_push(&lex->errors, lexing_error(lex->src, type, lex->start_loc, len)); +} + +static char lexer_advance(Lexer *lex) { + char c = lex->src->str[lex->current++]; + lex->loc.offset = lex->current; + lex->loc.column++; + if (c == '\n') { + lex->loc.line++; + lex->loc.column = 0; + } + return c; +} + +static bool lexer_match(Lexer *lex, char exp) { + if (lex->current >= lex->src->len) + return false; + if (lex->src->str[lex->current] != exp) + return false; + lexer_advance(lex); + return true; +} + +static bool lexer_match_not(Lexer *lex, char unexp) { + if (lex->current >= lex->src->len) + return false; + if (lex->src->str[lex->current] == unexp) + return false; + lexer_advance(lex); + return true; +} + +static char lexer_peek(Lexer *lex) { return lex->src->str[lex->current]; } + +inline static bool is_digit(char c) { return c >= '0' && c <= '9'; } +inline static uint64_t to_digit(char c) { return c - '0'; } +inline static bool is_ident_start(char c) { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); } +inline static bool is_ident(char c) { return is_ident_start(c) || is_digit(c) || c == '_'; } + +static void lexer_scan_number(Lexer *lex) { + // Get first digit (the one we already passed) + uint64_t lit = to_digit(lex->src->str[lex->start]); + uint32_t len = 1; + char c = lexer_peek(lex); + bool overflow = false; + while (is_digit(c)) { + uint64_t nlit = lit * 10 + to_digit(c); + if (nlit < lit) { // overflow + overflow = true; + } + lit = nlit; + lexer_advance(lex); + c = lexer_peek(lex); + len++; + } + + if (overflow) { + lexer_add_error(lex, LexingErrorNumberLiteralOverflow, len); + } + + lexer_add_token(lex, Number, len, lit); +} + +static uint32_t u32max(uint32_t a, uint32_t b) { return a > b ? a : b; } + +static inline __attribute__((always_inline)) void lexer_scan_ident(Lexer *lex) { + uint32_t len = 1; + while (is_ident(lexer_peek(lex))) { + lexer_advance(lex); + len++; + } + const char *s = &lex->src->str[lex->start]; +#define handle(x, str) else if (strncmp(str, s, u32max(sizeof(str) - 1, len)) == 0) lexer_add_token(lex, x, len, 0) + if (false) + ; + handle(Messages, "messages"); + handle(Struct, "struct"); + handle(Version, "version"); + handle(Const, "const"); + handle(Type, "type"); + else lexer_add_token(lex, Ident, len, 0); +#undef handle +} + +static void lexer_scan(Lexer *lex) { + char c = lexer_advance(lex); + switch (c) { + case '(': + lexer_add_token(lex, LeftParen, 1, 0); + break; + case ')': + lexer_add_token(lex, RightParen, 1, 0); + break; + case '{': + lexer_add_token(lex, LeftBrace, 1, 0); + break; + case '}': + lexer_add_token(lex, RightBrace, 1, 0); + break; + case '[': + lexer_add_token(lex, LeftBracket, 1, 0); + break; + case ']': + lexer_add_token(lex, RightBracket, 1, 0); + break; + case ',': + lexer_add_token(lex, Comma, 1, 0); + break; + case ';': + lexer_add_token(lex, Semicolon, 1, 0); + break; + case '&': + lexer_add_token(lex, Ampersand, 1, 0); + break; + case '^': + lexer_add_token(lex, Caret, 1, 0); + break; + case ':': + lexer_add_token(lex, Colon, 1, 0); + break; + case '=': + lexer_add_token(lex, Equal, 1, 0); + break; + case '#': + lexer_add_token(lex, Hash, 1, 0); + case '/': + if (lexer_match(lex, '/')) { + while (lexer_match_not(lex, '\n')) + ; + } + break; + case ' ': + case '\t': + case '\n': + case '\r': + break; + default: + if (is_digit(c)) { + lexer_scan_number(lex); + } else if (is_ident_start(c)) { + lexer_scan_ident(lex); + } else { + // Try to merge with the last error if possible + if (lex->errors.len > 0) { + LexingError *last_err = &lex->errors.data[lex->errors.len - 1]; + if (last_err->span.loc.line == lex->loc.line && last_err->type == LexingErrorUnexpectedCharacter && + last_err->span.loc.column + last_err->span.len == lex->start_loc.column) { + last_err->span.len++; + break; + } + } + lexer_add_error(lex, LexingErrorUnexpectedCharacter, 1); + } + } +} + +static void lexer_lex(Lexer *lex) { + while (lex->current < lex->src->len) { + lex->start = lex->current; + lex->start_loc = lex->loc; + lexer_scan(lex); + } + + lex->start = lex->current; + lex->start_loc = lex->loc; + + lexer_add_token(lex, Eof, 0, 0); +} + +static LexingResult lexer_finish(Lexer lex) { + return (LexingResult){ + .errors = lex.errors, + .tokens = lex.tokens, + }; +} + +LexingResult lex(Source *src) { + Lexer lex = lexer_init(src); + lexer_lex(&lex); + return lexer_finish(lex); +} + +void lexing_error_report(LexingError *le) { + ReportSpan span = {.span = le->span, .sev = ReportSeverityError}; +#define report(fmt, ...) source_report(le->src, le->span.loc, ReportSeverityError, &span, 1, NULL, fmt, __VA_ARGS__); + switch (le->type) { + case LexingErrorUnexpectedCharacter: + report("Unexpected character%s '%.*s'", le->span.len > 1 ? "s" : "", le->span.len, &le->src->str[le->span.loc.offset]); + break; + case LexingErrorNumberLiteralOverflow: + report("number literal '%.*s' overflows max value of %lu", le->span.len, &le->src->str[le->span.loc.offset], UINT64_MAX); + break; + default: + break; + } +#undef report +} + +char *token_type_string(TokenType t) { + TokenType types[TOKEN_TYPE_COUNT]; + size_t count = 0; +#define handle(type) \ + if (t & type) \ + types[count++] = type + handle(LeftParen); + handle(RightParen); + handle(LeftBrace); + handle(RightBrace); + handle(LeftBracket); + handle(RightBracket); + handle(Comma); + handle(Semicolon); + handle(Ampersand); + handle(Caret); + handle(Colon); + handle(Equal); + handle(Ident); + handle(Number); + handle(Messages); + handle(Struct); + handle(Version); + handle(Const); + handle(Type); + handle(Eof); +#undef handle + CharVec str = vec_init(); + for (size_t i = 0; i < count; i++) { + if (i == 0) { + } else if (i == count - 1) { + vec_push_array(&str, " or ", 4); + } else { + vec_push_array(&str, ", ", 2); + } + + switch (types[i]) { + case LeftParen: + vec_push_array(&str, "'('", 3); + break; + case RightParen: + vec_push_array(&str, "')'", 3); + break; + case LeftBrace: + vec_push_array(&str, "'{'", 3); + break; + case RightBrace: + vec_push_array(&str, "'}'", 3); + break; + case LeftBracket: + vec_push_array(&str, "'['", 3); + break; + case RightBracket: + vec_push_array(&str, "']'", 3); + break; + case Comma: + vec_push_array(&str, "','", 3); + break; + case Semicolon: + vec_push_array(&str, "';'", 3); + break; + case Ampersand: + vec_push_array(&str, "'&'", 3); + break; + case Caret: + vec_push_array(&str, "'^'", 3); + break; + case Colon: + vec_push_array(&str, "':'", 3); + break; + case Equal: + vec_push_array(&str, "'='", 3); + break; + case Hash: + vec_push_array(&str, "'#'", 3); + break; + case Ident: + vec_push_array(&str, "identifier", 10); + break; + case Number: + vec_push_array(&str, "number literal", 15); + break; + case Messages: + vec_push_array(&str, "keyword messages", 16); + break; + case Struct: + vec_push_array(&str, "keyword struct", 14); + break; + case Version: + vec_push_array(&str, "keyword version", 15); + break; + case Const: + vec_push_array(&str, "keyword const", 13); + break; + case Type: + vec_push_array(&str, "keyword type", 12); + break; + case Eof: + vec_push_array(&str, "end of file", 11); + break; + } + } + + vec_push(&str, '\0'); + return str.data; +} diff --git a/ser/lexer.h b/ser/lexer.h new file mode 100644 index 0000000..de4a931 --- /dev/null +++ b/ser/lexer.h @@ -0,0 +1,107 @@ +#ifndef LEXER_H +#define LEXER_H +#include "source.h" +#include "vector_impl.h" + +#include + +typedef enum : uint32_t { + LeftParen = 1 << 0, + RightParen = 1 << 1, + LeftBrace = 1 << 2, + RightBrace = 1 << 3, + LeftBracket = 1 << 4, + RightBracket = 1 << 5, + Comma = 1 << 6, + Semicolon = 1 << 7, + Ampersand = 1 << 8, + Caret = 1 << 9, + Colon = 1 << 10, + Equal = 1 << 11, + Hash = 1 << 12, + Ident = 1 << 13, + Number = 1 << 14, + Messages = 1 << 15, + Struct = 1 << 16, + Version = 1 << 17, + Const = 1 << 18, + Type = 1 << 19, + Eof = 1 << 20, +} TokenType; + +#define TOKEN_TYPE_COUNT 21 + +typedef struct { + // The type of the token + TokenType type; + // Span of the lexeme (line, columnn, offset, length) + Span span; + // A pointer to the start of the lexeme (not null terminated) + const char *lexeme; + // Pointer to the source object + Source *src; + // In the case of a Number token: the parsed number + uint64_t lit; +} Token; + +typedef enum : uint32_t { + LexingErrorNoError, + LexingErrorUnexpectedCharacter, + LexingErrorNumberLiteralOverflow, +} LexingErrorType; + +typedef struct { + Source *src; + Span span; + LexingErrorType type; +} LexingError; +// Destroy the token +void token_drop(Token t); +// Destroy lexing error +void lexing_error_drop(LexingError e); + +VECTOR_IMPL(Token, TokenVec, token, token_drop); +VECTOR_IMPL(LexingError, LexingErrorVec, lexing_error, lexing_error_drop); + +typedef struct { + TokenVec tokens; + LexingErrorVec errors; +} LexingResult; + +LexingResult lex(Source *src); + +void lexing_result_drop(LexingResult res); + +void lexing_error_report(LexingError *le); + +__attribute__((unused)) static inline const char *token_type_name(TokenType t) { +#define _case(type) \ + case type: \ + return #type + switch (t) { + _case(LeftParen); + _case(RightParen); + _case(LeftBrace); + _case(RightBrace); + _case(LeftBracket); + _case(RightBracket); + _case(Comma); + _case(Semicolon); + _case(Ampersand); + _case(Caret); + _case(Colon); + _case(Equal); + _case(Hash); + _case(Ident); + _case(Number); + _case(Messages); + _case(Struct); + _case(Version); + _case(Const); + _case(Type); + _case(Eof); + } +#undef _case +} +char *token_type_string(TokenType t); +#endif diff --git a/ser/log.c b/ser/log.c new file mode 100644 index 0000000..da24e10 --- /dev/null +++ b/ser/log.c @@ -0,0 +1,143 @@ +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include + +#define BASE_BUFFER_SIZE 1024 +#define SOURCE_BUFFER_SIZE 128 +static char BASE_BUFFER[BASE_BUFFER_SIZE] = {0}; +static char SOURCE_BUFFER[SOURCE_BUFFER_SIZE] = {0}; + +// TODO: mutex +static Logger LOGGER = {.sevs = Info | Warning | Error}; + +void logger_set_fd(FILE *fd) { LOGGER.fd = fd; } + +void logger_enable_severities(LogSeverities sevs) { LOGGER.sevs |= sevs; } + +void logger_disable_severities(LogSeverities sevs) { LOGGER.sevs &= ~sevs; } + +void logger_set_severities(LogSeverities sevs) { LOGGER.sevs = sevs; } + +void logger_init() { LOGGER.initialized = true; } + +// Logging function, should be rarely called by itself (use the log_* macros instead) +// message takes the form: (file:line?) SEVERITY func > fmt ... +// line can be ignored if negative +void _log_severity(LogSeverity sev, const char *func, const char *file, const int line, char *fmt, ...) { + if (!LOGGER.initialized) { + fprintf(stderr, "Trying to log, but the logger hasn't been initialized.\n"); + return; + } + + // Ignore if the logger doesn't have a configured target or if the severity is ignored. + if (LOGGER.fd == NULL || !(LOGGER.sevs & sev)) { + return; + } + + // format source in second half of buffer + int source_len; + if (line >= 0) { + source_len = snprintf(SOURCE_BUFFER, SOURCE_BUFFER_SIZE, "(%s:%d)", file, line); + } else { + source_len = snprintf(SOURCE_BUFFER, SOURCE_BUFFER_SIZE, "(%s)", file); + } + + // Keep track of width for alignment + if (source_len > LOGGER.source_width) { + LOGGER.source_width = source_len; + } + + // "format" severity + const char *sev_str; + switch (sev) { + case Trace: + sev_str = "\033[0;35mTRACE"; + break; + case Debug: + sev_str = "\033[0;34mDEBUG"; + break; + case Info: + sev_str = "\033[0;32mINFO "; + break; + case Warning: + sev_str = "\033[0;33mWARN "; + break; + case Error: + sev_str = "\033[0;31mERROR"; + break; + default: + sev_str = "\033[0;31m?????"; + break; + } + + // no format for func since there is nothing to do + + // SAFETY: func should always come from the __func__ macro, which shouldn't allow buffer overflow. + int func_len = strlen(func); + + // Keep track of width for alignment + if (func_len > LOGGER.func_width) { + LOGGER.func_width = func_len; + } + + // Final string buffer + char *buf = BASE_BUFFER; + int prefix_len = snprintf( + buf, + BASE_BUFFER_SIZE / 2, + "\033[0;2m%-*s %s \033[0;1m%-*s \033[0;2m> ", + LOGGER.source_width, + SOURCE_BUFFER, + sev_str, + LOGGER.func_width, + func + ); + + const char *suffix = "\033[0m\n"; + const int suffix_len = 5; + + // max slice of the buffer used by the message + char *str = buf + prefix_len; + int str_size = BASE_BUFFER_SIZE - prefix_len - suffix_len; // -1 for the trailing newline + + va_list args; + va_start(args, fmt); + int len = vsnprintf(str, str_size, fmt, args); + va_end(args); + + // Make sure we have enough space in the BASE_BUFFER, allocate if we don't + if (len >= str_size) { + buf = malloc(prefix_len + len + suffix_len); + str = buf + prefix_len; + + if (buf == NULL) { + fprintf(stderr, "Couldn't allocate buffer (Out of memory ?), aborting...\n"); + exit(1); + } + + // Copy over prefix into new buffer + memcpy(buf, BASE_BUFFER, prefix_len * sizeof(char)); + + va_start(args, fmt); + vsnprintf(str, len + 1, fmt, args); + va_end(args); + } + + memcpy(buf + prefix_len + len, suffix, suffix_len * sizeof(char)); + + fwrite(buf, 1, prefix_len + len + suffix_len, LOGGER.fd); + +#ifdef LOG_FLUSH + fflush(LOGGER.fd); +#endif + + if (buf != BASE_BUFFER) { + free(buf); + } +} diff --git a/ser/log.h b/ser/log.h new file mode 100644 index 0000000..d097fb0 --- /dev/null +++ b/ser/log.h @@ -0,0 +1,53 @@ +#ifndef LOG_H +#define LOG_H + +#include +#include +#include + +// Bit field of severities +typedef uint32_t LogSeverities; + +// The logger +typedef struct { + bool initialized; + FILE *fd; + LogSeverities sevs; + int source_width; + int func_width; +} Logger; + +// A message's severity Error > Warning > Info > Debug > Trace +typedef enum : LogSeverities { + Trace = 1 << 0, + Debug = 1 << 1, + Info = 1 << 2, + Warning = 1 << 3, + Error = 1 << 4, +} LogSeverity; + +// Needs to be here but log_* macros should be used instead +void _log_severity(LogSeverity sev, const char *func, const char *file, const int line, char *fmt, ...); + +// Set the file desciptor for the logger +void logger_set_fd(FILE *fd); +void logger_enable_severities(LogSeverities sevs); +void logger_disable_severities(LogSeverities sevs); +void logger_set_severities(LogSeverities sevs); +void logger_init(); + +#ifdef LOG_DISABLE +#define log_trace(...) (void)0 +#define log_debug(...) (void)0 +#define log_info(...) (void)0 +#define log_warn(...) (void)0 +#define log_error(...) (void)0 +#else +#define log_trace(...) _log_severity(Trace, __func__, __FILE__, __LINE__, __VA_ARGS__) +#define log_debug(...) _log_severity(Debug, __func__, __FILE__, __LINE__, __VA_ARGS__) +#define log_info(...) _log_severity(Info, __func__, __FILE__, __LINE__, __VA_ARGS__) +#define log_warn(...) _log_severity(Warning, __func__, __FILE__, __LINE__, __VA_ARGS__) +#define log_error(...) _log_severity(Error, __func__, __FILE__, __LINE__, __VA_ARGS__) +#endif + +#endif diff --git a/ser/macro_utils.h b/ser/macro_utils.h new file mode 100644 index 0000000..9254270 --- /dev/null +++ b/ser/macro_utils.h @@ -0,0 +1,67 @@ +#ifndef MACRO_UTILS_H +#define MACRO_UTILS_H + +#define CALL(m, ...) m(__VA_ARGS__) + +#define EMPTY() + +#define EVAL(...) EVAL32(__VA_ARGS__) +#define EVAL1024(...) EVAL512(EVAL512(__VA_ARGS__)) +#define EVAL512(...) EVAL256(EVAL256(__VA_ARGS__)) +#define EVAL256(...) EVAL128(EVAL128(__VA_ARGS__)) +#define EVAL128(...) EVAL64(EVAL64(__VA_ARGS__)) +#define EVAL64(...) EVAL32(EVAL32(__VA_ARGS__)) +#define EVAL32(...) EVAL16(EVAL16(__VA_ARGS__)) +#define EVAL16(...) EVAL8(EVAL8(__VA_ARGS__)) +#define EVAL8(...) EVAL4(EVAL4(__VA_ARGS__)) +#define EVAL4(...) EVAL2(EVAL2(__VA_ARGS__)) +#define EVAL2(...) EVAL1(EVAL1(__VA_ARGS__)) +#define EVAL1(...) __VA_ARGS__ +#define EVAL0(...) + +#define SND(a, b, ...) b +#define FST(a, ...) a +#define CAT(a, b) a##b +#define STR(a) #a + +#define PROBE() ~, 1 +#define IS_PROBE(...) SND(__VA_ARGS__, 0) +// _FAST_NOT(0) -> 1 _FAST_NOT(1) -> 0 +#define _FAST_NOT(x) CAT(_FAST_NOT_, x)() +#define _FAST_NOT_0() 1 +#define _FAST_NOT_1() 0 +// NOT(0) -> 1 NOT(...) -> 0 +#define NOT(x) IS_PROBE(CAT(_NOT_, x)) +#define _NOT_0 PROBE() +// BOOL(0) -> 0 BOOL(...) -> 1 +#define BOOL(x) _FAST_NOT(NOT(x)) + +// Same as EVAL1 but different meaning +#define KEEP(...) __VA_ARGS__ +// Drop / Delete the arguments +#define DROP(...) + +#define IF_ELSE(c) FAST_IF_ELSE(BOOL(c)) +// IF_ELSE if c is know to be 0 or 1 +#define FAST_IF_ELSE(c) CAT(_IF_ELSE_, c) +#define _IF_ELSE_0(...) KEEP +#define _IF_ELSE_1(...) __VA_ARGS__ DROP + +#define HAS_ARGS(...) BOOL(FST(_HAS_ARGS_ __VA_ARGS__)()) +#define _HAS_ARGS_() 0 +#define IF_ELSE_ARGS(...) FAST_IF_ELSE(HAS_ARGS(__VA_ARGS__)) + +#define DEFER1(x) x EMPTY() +#define DEFER2(x) x EMPTY EMPTY()() +#define DEFER3(x) x EMPTY EMPTY EMPTY()()() +#define DEFER4(x) x EMPTY EMPTY EMPTY EMPTY()()()() +#define DEFER5(x) x EMPTY EMPTY EMPTY EMPTY EMPTY()()()()() + +#define MAP(m, fst, ...) m(fst) __VA_OPT__(DEFER1(_MAP)()(DEFER1(m), __VA_ARGS__)) +#define _MAP() MAP + +#define REVERSE(...) IF_ELSE_ARGS(__VA_ARGS__)(EVAL(_REVERSE(__VA_ARGS__)))() +#define _REVERSE(a, ...) __VA_OPT__(DEFER1(__REVERSE)()(__VA_ARGS__), ) a +#define __REVERSE() _REVERSE + +#endif diff --git a/ser/main.c b/ser/main.c new file mode 100644 index 0000000..908cb40 --- /dev/null +++ b/ser/main.c @@ -0,0 +1,143 @@ +#include "ast.h" +#include "codegen_c.h" +#include "codegen_python.h" +#include "hashmap.h" +#include "lexer.h" +#include "log.h" +#include "parser.h" +#include "source.h" + +#include +#include + +void abort_error(uint32_t error_count) { + fprintf(stderr, "\033[1;91merror\033[0m: aborting due to previous error%s\n", error_count > 1 ? "s" : ""); + exit(1); +} + +typedef enum { + BackendC, + BackendPython, +} Backend; + +static Hashmap *backend_map = NULL; + +typedef struct { + StringSlice name; + Backend b; +} BackendString; + +impl_hashmap_delegate(backend, BackendString, string_slice, name); + +Backend parse_backend(const char *b) { + if (backend_map == NULL) { + backend_map = hashmap_init(backend_hash, backend_equal, NULL, sizeof(BackendString)); + hashmap_set(backend_map, &(BackendString){.name = STRING_SLICE("c"), .b = BackendC}); + hashmap_set(backend_map, &(BackendString){.name = STRING_SLICE("python"), .b = BackendPython}); + } + + BackendString *backend = hashmap_get(backend_map, &(BackendString){.name.ptr = b, .name.len = strlen(b)}); + if (backend == NULL) { + log_error("Unknown backend '%s'", b); + exit(1); + } + return backend->b; +} + +int main(int argc, char **argv) { + logger_set_fd(stderr); + logger_enable_severities(Info | Warning | Error); + logger_init(); + + if (argc != 4) { + fprintf(stderr, "Expected 3 arguments: ser \n"); + } + + char *source_path = argv[1]; + Backend back = parse_backend(argv[2]); + char *output = argv[3]; + + Source src; + SourceError serr = source_open(source_path, &src); + + if (serr != SourceErrorNoError) { + log_error("Error when opening or reading source"); + exit(1); + } + + LexingResult lexing_result = lex(&src); + if (lexing_result.errors.len > 0) { + for (size_t i = 0; i < lexing_result.errors.len; i++) { + lexing_error_report(&lexing_result.errors.data[i]); + } + abort_error(lexing_result.errors.len); + } + vec_drop(lexing_result.errors); + + ParsingResult parsing_result = parse(lexing_result.tokens); + + if (parsing_result.errors.len > 0) { + for (size_t i = 0; i < parsing_result.errors.len; i++) { + parsing_error_report(&src, &parsing_result.errors.data[i]); + } + abort_error(parsing_result.errors.len); + } + vec_drop(parsing_result.errors); + + EvaluationResult evaluation_result = resolve_statics(&parsing_result.ctx); + if (evaluation_result.errors.len > 0) { + for (size_t i = 0; i < evaluation_result.errors.len; i++) { + eval_error_report(&src, &evaluation_result.errors.data[i]); + } + abort_error(evaluation_result.errors.len); + } + vec_drop(evaluation_result.errors); + + switch (back) { + case BackendC: { + char *basename; + { + char *last_slash = strrchr(output, '/'); + if (last_slash == NULL) { + basename = output; + } else { + basename = last_slash + 1; + } + } + char *header_path = msprintf("%s.h", output); + char *source_path = msprintf("%s.c", output); + + FileWriter header = file_writer_init(header_path); + FileWriter source = file_writer_init(source_path); + + codegen_c((Writer *)&header, (Writer *)&source, basename, &evaluation_result.program); + + file_writer_drop(header); + file_writer_drop(source); + + free(source_path); + free(header_path); + break; + } + case BackendPython: { + FileWriter source = file_writer_init(output); + + codegen_python((Writer *)&source, &evaluation_result.program); + + file_writer_drop(source); + break; + } + default: + log_error("What the fuck ?"); + exit(1); + } + + program_drop(evaluation_result.program); + ast_drop(parsing_result.ctx); + vec_drop(lexing_result.tokens); + source_drop(src); + + if (backend_map != NULL) { + hashmap_drop(backend_map); + } +} diff --git a/ser/parser.c b/ser/parser.c new file mode 100644 index 0000000..2f66355 --- /dev/null +++ b/ser/parser.c @@ -0,0 +1,350 @@ +#include "parser.h" + +#include "ast.h" +#include "lexer.h" +#include "vector.h" + +#include +#include + +typedef struct { + TokenVec tokens; + ParsingErrorVec errors; + AstContext ctx; + uint32_t current; +} Parser; + +static Parser parser_init(TokenVec tokens) { + return (Parser){ + .tokens = tokens, + .ctx = ast_init(), + .current = 0, + .errors = vec_init(), + }; +} + +inline static ParsingError err_expected(TokenType type, Span span) { + return (ParsingError){.span = span, .type = ParsingErrorUnexpectedToken, .data.type = type}; +} + +inline static void add_error(Parser *p, ParsingError err) { vec_push(&p->errors, err); } + +inline static Token peek(Parser *p) { return p->tokens.data[p->current]; } + +inline static Token previous(Parser *p) { return p->tokens.data[p->current - 1]; } + +static bool check(Parser *p, TokenType type) { + if (peek(p).type == Eof) { + return type == Eof; + } + return peek(p).type == type; +} + +static Token advance(Parser *p) { + if (peek(p).type != Eof) + p->current++; + return previous(p); +} + +static bool match(Parser *p, TokenType t) { + if (check(p, t)) { + advance(p); + return true; + } + return false; +} + +static void skip_until(Parser *p, TokenType type) { + while ((peek(p).type & (type | Eof)) == 0) { + advance(p); + } +} + +static bool consume(Parser *p, TokenType t, Token *res) { + if (peek(p).type == t) { + if (res != NULL) { + *res = advance(p); + } else { + advance(p); + } + return true; + } + add_error(p, err_expected(t, peek(p).span)); + return false; +} + +#define bubble(...) \ + if (!(__VA_ARGS__)) { \ + return false; \ + } + +static Location parser_loc(Parser *p) { return p->tokens.data[p->current].span.loc; } + +static inline Span span_end(Parser *p, Location start) { + Span prev = previous(p).span; + return span_from_to( + start, + (Location){.line = prev.loc.line, .column = prev.loc.column + prev.len, .offset = prev.loc.offset + prev.len} + ); +} + +static bool parse_number(Parser *p, AstNumber *res) { + Token t = advance(p); + if (t.type == Number) { + *res = ast_number(p->ctx, t.span, t); + return true; + } + if (t.type == Ident) { + *res = ast_number(p->ctx, t.span, t); + return true; + } + add_error(p, err_expected(Number | Ident, t.span)); + return false; +} + +static bool parse_ident(Parser *p, AstIdent *res) { + Token t = advance(p); + if (t.type == Ident) { + *res = ast_ident(p->ctx, t.span, t); + return true; + } + add_error(p, err_expected(Ident, t.span)); + return false; +} + +static bool parse_size(Parser *p, AstSize *res) { + Location start = parser_loc(p); + if (check(p, RightBracket)) { + *res = ast_no_size(p->ctx, span_end(p, start)); + return true; + } + AstNumber size; + if (match(p, Caret)) { + bubble(parse_number(p, &size)); + *res = ast_max_size(p->ctx, span_end(p, start), size); + return true; + } + bubble(parse_number(p, &size)); + *res = ast_fixed_size(p->ctx, span_end(p, start), size); + return true; +} + +static bool parse_type(Parser *p, AstType *res) { + bubble(parse_ident(p, &res->ident)); + + Location start = parser_loc(p); + TokenType next = peek(p).type; + while (next == Ampersand || next == LeftBracket) { + AstType *type = arena_alloc(&p->ctx.alloc, sizeof(AstType)); + *type = *res; + AstSize size; + bool heap = match(p, Ampersand); + bubble(consume(p, LeftBracket, NULL)); + bubble(parse_size(p, &size)); + bubble(consume(p, RightBracket, NULL)); + if (heap || size.tag == ATNoSize) { + res->array = ast_heap_array(p->ctx, span_end(p, start), type, size); + } else { + res->array = ast_field_array(p->ctx, span_end(p, start), type, size); + } + next = peek(p).type; + } + return true; +} + +static bool parse_field(Parser *p, AstField *res) { + Token name; + AstType type; + Location start = parser_loc(p); + bubble(consume(p, Ident, &name)); + bubble(consume(p, Colon, NULL)); + bubble(parse_type(p, &type)); + *res = ast_field(p->ctx, span_end(p, start), name, type); + return true; +} + +static bool parse_message(Parser *p, AstMessage *res) { + Token name; + AstFieldVec fields = vec_init(); + Location start = parser_loc(p); + bubble(consume(p, Ident, &name)); + bubble(consume(p, LeftBrace, NULL)); + + AstField f; + do { + if (check(p, RightBrace)) { + break; + } + if (parse_field(p, &f)) { + vec_push(&fields, f); + } else { + skip_until(p, Comma | Ident | RightBrace); + } + } while (match(p, Comma)); + bubble(consume(p, RightBrace, NULL)); + *res = ast_message(p->ctx, span_end(p, start), name, fields); + return true; +} + +static bool parse_attribute(Parser *p, AstAttribute *res) { + Token ident; + Location start = parser_loc(p); + bubble(consume(p, Hash, NULL)); + bubble(consume(p, LeftBracket, NULL)); + bubble(consume(p, Ident, &ident)); + bubble(consume(p, RightBracket, NULL)); + *res = ast_attribute(p->ctx, span_end(p, start), ident); + return true; +} + +static bool parse_attribute_or_message(Parser *p, AstAttributeOrMessage *res) { + if (check(p, Hash)) { + return parse_attribute(p, &res->attribute); + } else if (check(p, Ident)) { + return parse_message(p, &res->message); + } else { + vec_push(&p->errors, err_expected(Hash | Ident, peek(p).span)); + return false; + } +} + +static bool parse_version(Parser *p, AstVersion *res) { + AstNumber ver; + Location start = parser_loc(p); + bubble(consume(p, Version, NULL)); + bubble(consume(p, LeftParen, NULL)); + bubble(parse_number(p, &ver)); + bubble(consume(p, RightParen, NULL)); + bubble(consume(p, Semicolon, NULL)); + *res = ast_version(p->ctx, span_end(p, start), ver); + return true; +} + +static bool parse_struct(Parser *p, AstStruct *res) { + Token name; + AstFieldVec fields = vec_init(); + Location start = parser_loc(p); + bubble(consume(p, Struct, NULL)); + bubble(consume(p, Ident, &name)); + bubble(consume(p, LeftBrace, NULL)); + + AstField f; + do { + if (check(p, RightBrace)) { + break; + } + if (parse_field(p, &f)) { + vec_push(&fields, f); + } else { + skip_until(p, Comma | Ident | RightBrace); + } + } while (match(p, Comma)); + bubble(consume(p, RightBrace, NULL)); + *res = ast_struct(p->ctx, span_end(p, start), name, fields); + return true; +} + +static bool parse_type_decl(Parser *p, AstTypeDecl *res) { + Token name; + AstType type; + Location start = parser_loc(p); + bubble(consume(p, Type, NULL)); + bubble(consume(p, Ident, &name)); + bubble(consume(p, Equal, NULL)); + bubble(parse_type(p, &type)); + bubble(consume(p, Semicolon, NULL)); + *res = ast_type_decl(p->ctx, span_end(p, start), name, type); + return true; +} + +static bool parse_messages(Parser *p, AstMessages *res) { + AstAttributeOrMessageVec children = vec_init(); + AstAttributeOrMessage child; + Token name; + Location start = parser_loc(p); + bubble(consume(p, Messages, NULL)); + bubble(consume(p, Ident, &name)); + bubble(consume(p, LeftBrace, NULL)); + while (!match(p, RightBrace)) { + if (parse_attribute_or_message(p, &child)) { + vec_push(&children, child); + } else { + skip_until(p, RightBrace | Hash | Ident); + } + } + *res = ast_messages(p->ctx, span_end(p, start), name, children); + return true; +} + +static bool parse_constant(Parser *p, AstConstant *res) { + Token name; + AstNumber value; + Location start = parser_loc(p); + bubble(consume(p, Const, NULL)); + bubble(consume(p, Ident, &name)); + bubble(consume(p, Equal, NULL)); + bubble(parse_number(p, &value)); + bubble(consume(p, Semicolon, NULL)); + *res = ast_constant(p->ctx, span_end(p, start), name, value); + return true; +} + +static bool parse_item(Parser *p, AstItem *res) { + switch (peek(p).type) { + case Version: + return parse_version(p, &res->version); + case Struct: + return parse_struct(p, &res->struct_); + case Type: + return parse_type_decl(p, &res->type_decl); + case Messages: + return parse_messages(p, &res->messages); + case Const: + return parse_constant(p, &res->constant); + default: + // TODO: error handling + advance(p); + return false; + } +} + +static bool parse_items(Parser *p, AstItems *res) { + AstItemVec items = vec_init(); + AstItem item; + Location start = parser_loc(p); + while (!check(p, Eof)) { + if (parse_item(p, &item)) { + vec_push(&items, item); + } else { + skip_until(p, Version | Struct | Type | Messages | Const); + } + } + *res = ast_items(p->ctx, span_end(p, start), items); + return true; +} + +ParsingResult parse(TokenVec vec) { + Parser p = parser_init(vec); + AstNode *items = arena_alloc(&p.ctx.alloc, sizeof(AstNode)); + parse_items(&p, &items->items); + p.ctx.root = items; + return (ParsingResult){.ctx = p.ctx, .errors = p.errors}; +} + +void parsing_error_report(Source *src, ParsingError *err) { + ReportSpan span = {.span = err->span, .sev = ReportSeverityError}; +#define report(fmt, ...) source_report(src, err->span.loc, ReportSeverityError, &span, 1, NULL, fmt, __VA_ARGS__); + switch (err->type) { + case ParsingErrorUnexpectedToken: { + char *type = token_type_string(err->data.type); + span.message = msprintf("expected %s", type); + report("Expected %s, found '%.*s'", type, err->span.len, &src->str[err->span.loc.offset]); + free((char *)span.message); + free(type); + break; + } + default: + break; + } +#undef report +} diff --git a/ser/parser.h b/ser/parser.h new file mode 100644 index 0000000..c469d58 --- /dev/null +++ b/ser/parser.h @@ -0,0 +1,35 @@ +#ifndef PARSER_H +#define PARSER_H +#include "ast.h" +#include "lexer.h" +#include "source.h" +#include "vector.h" +#include "vector_impl.h" + +typedef union { + TokenType type; +} ParsingErrorData; + +typedef enum { + ParsingErrorNoError, + ParsingErrorUnexpectedToken, +} ParsingErrorType; + +typedef struct { + Span span; + ParsingErrorType type; + ParsingErrorData data; +} ParsingError; + +VECTOR_IMPL(ParsingError, ParsingErrorVec, parsing_error); + +typedef struct { + AstContext ctx; + ParsingErrorVec errors; +} ParsingResult; + +ParsingResult parse(TokenVec vec); + +void parsing_error_report(Source *src, ParsingError *err); + +#endif diff --git a/ser/source.c b/ser/source.c new file mode 100644 index 0000000..1c296e7 --- /dev/null +++ b/ser/source.c @@ -0,0 +1,297 @@ +#include "source.h" + +#include "assert.h" +#include "vector.h" + +#include +#include +#include +#include + +uint32_t sss_hash(Hasher state, const void *v) { + SpannedStringSlice *sss = (SpannedStringSlice *)v; + return string_slice_hash(state, &sss->slice); +} +bool sss_equal(const void *a, const void *b) { + SpannedStringSlice *sa = (SpannedStringSlice *)a; + SpannedStringSlice *sb = (SpannedStringSlice *)b; + return string_slice_equal(sa, sb); +} + +Source source_init(const char *str, uint32_t len) { + char *ptr = malloc(len + 1); + assert_alloc(ptr); + strncpy(ptr, str, len); + ptr[len] = '\0'; + // Will initlalize ref_count to 0 in DEBUG mode as well + return (Source){.str = ptr, .len = len, .path = NULL}; +} +SourceError source_from_file(FILE *f, Source *src) { + fseek(f, 0, SEEK_END); + uint64_t len = ftell(f); + fseek(f, 0, SEEK_SET); + char *ptr = malloc(len + 1); + + if (fread(ptr, 1, len, f) != len) { + return SourceErrorReadFailed; + } + + IF_DEBUG(src->ref_count = 0); + src->str = ptr; + src->len = len; + src->path = NULL; + return SourceErrorNoError; +} +SourceError source_open(const char *path, Source *src) { + FILE *f = fopen(path, "r"); + if (f == NULL) { + return SourceErrorOpenFailed; + } + + SourceError err = source_from_file(f, src); + fclose(f); + + if (err == SourceErrorNoError) { + char *p = strdup(path); + assert_alloc(p); + src->path = p; + } + + return err; +} +void source_drop(Source src) { + IF_DEBUG({ + if (src.ref_count > 0) { + log_error("Trying to destroy currently used source, leaking instead"); + return; + } + }); + if (src.path != NULL) { + free((char *)src.path); + } + free((char *)src.str); +} + +int span_compare(const void *sa, const void *sb) { + Span *a = (Span *)sa; + Span *b = (Span *)sb; + int line = a->loc.line - b->loc.line; + if (line != 0) + return line; + int column = b->loc.column - a->loc.column; + if (column != 0) + return column; + return a->len - b->len; +} + +static int report_span_compare(const void *va, const void *vb) { + ReportSpan *a = (ReportSpan *)va; + ReportSpan *b = (ReportSpan *)vb; + return span_compare(&a->span, &b->span); +} + +void source_report( + const Source *src, + Location loc, + ReportSeverity sev, + const ReportSpan *pspans, + uint32_t span_count, + const char *help, + const char *fmt, + ... +) { + va_list args; + va_start(args, fmt); + int len = vsnprintf(NULL, 0, fmt, args); + va_end(args); + char *message = malloc(len + 1); + assert_alloc(message); + va_start(args, fmt); + vsnprintf(message, len + 1, fmt, args); + va_end(args); + + ReportSpanVec spans = vec_init(); + vec_push_array(&spans, pspans, span_count); + qsort(spans.data, spans.len, sizeof(ReportSpan), report_span_compare); + + const char *s; + switch (sev) { + case ReportSeverityError: + s = "\033[91merror"; + break; + case ReportSeverityWarning: + s = "\033[93mwarning"; + break; + case ReportSeverityNote: + s = "\033[92mnote"; + break; + default: + s = "?????"; + break; + } + const char *file; + if (src->path == NULL) { + file = ""; + } else { + file = src->path; + } + + uint32_t last_line, first_line; + if (spans.len > 0) { + last_line = spans.data[spans.len - 1].span.loc.line; + first_line = spans.data[0].span.loc.line; + } else { + last_line = loc.line; + first_line = loc.line; + } + + uint32_t pad = floor(log10(last_line)) + 2; + + fprintf( + stderr, + "\033[1m%s\033[0;1m: %s\n%*s\033[94m--> \033[0m%s:%d:%d\n%*s\033[1;94m|\n", + s, + message, + pad - 1, + "", + file, + loc.line, + loc.column, + pad, + "" + ); + + free(message); + + // The line of the span + StyledString line_str = styled_string_init(); + // Extra lines used when no more space in the sub + StyledStringVec sub_strs = vec_init(); + uint32_t line_length; + uint32_t offset; + + last_line = first_line - 1; + for (uint32_t i = 0; i < spans.len; i++) { + ReportSpan rspan = spans.data[i]; + Span span = rspan.span; + + offset = span.loc.offset - span.loc.column; + uint32_t line_end_off = offset; + while (line_end_off < src->len && src->str[line_end_off] != '\n') { + line_end_off++; + } + + uint32_t line_delta = span.loc.line - last_line; + + line_length = line_end_off - offset; + last_line = span.loc.line; + + styled_string_clear(&line_str); + vec_clear(&sub_strs); + vec_push(&sub_strs, styled_string_init()); + styled_string_set(&line_str, 0, NULL, &src->str[offset], line_length); + + while (i < spans.len && spans.data[i].span.loc.line == last_line) { + ReportSpan rspan = spans.data[i]; + Span span = rspan.span; + ReportSeverity span_sev = rspan.sev; + + const char *sev_style = "\033[1;97m"; + char underline = ' '; + switch (span_sev) { + case ReportSeverityError: + sev_style = "\033[1;91m"; + underline = '^'; + break; + case ReportSeverityWarning: + sev_style = "\033[1;93m"; + underline = '^'; + break; + case ReportSeverityNote: + sev_style = "\033[1;94m"; + underline = '-'; + break; + } + + styled_string_set_style(&line_str, span.loc.column, sev_style, span.len); + styled_string_set_style(sub_strs.data, span.loc.column, sev_style, span.len); + styled_string_fill(&sub_strs.data[0], span.loc.column, underline, span.len); + + // Not a loop, but I want break; + while (rspan.message != NULL) { + int mlen = strlen(rspan.message); + size_t index = span.loc.column + span.len + 1; + if (styled_string_available_space(&sub_strs.data[0], index, mlen + 1) > mlen) { + styled_string_set(&sub_strs.data[0], index, sev_style, rspan.message, mlen); + // We got the message in + break; + } + + index = span.loc.column; + + // We never put any message on the second sub string, so it needs to exist if we put one on the third + if (sub_strs.len == 1) { + vec_push(&sub_strs, styled_string_init()); + } + + // Start looking at the third sub line + size_t line = 2; + while (true) { + // The line doesn't exist yet: it is available + if (line >= sub_strs.len) { + vec_push(&sub_strs, styled_string_init()); + break; + } + // Check if the subline is ok + if (styled_string_available_space(&sub_strs.data[line], index, mlen + 1) > mlen) { + break; + } + // We couldn't find any space, continue on the next line. + line++; + } + + for (size_t l = 1; l < line; l++) { + styled_string_set(&sub_strs.data[l], index, sev_style, "|", 1); + } + + styled_string_set(&sub_strs.data[line], index, sev_style, rspan.message, mlen); + break; + } + + i++; + } + // We exited the loop, i points to a span on the next line or to the end of spans + // Se decrement it because it'll get reincremented by the outer for loop + i--; + + // Print elipsies if we skipped more than a line + if (line_delta > 2) { + fprintf(stderr, "\033[1;94m...\n"); + } else if (line_delta > 1) { + uint32_t off_end = offset - 1; + uint32_t off_start = off_end; + while (src->str[off_start - 1] != '\n' && off_start > 0) { + off_start--; + } + uint32_t len = off_end - off_start; + fprintf(stderr, "\033[1;94m%*d |\033[0m %.*s\n", pad - 1, last_line - 1, len, &src->str[off_start]); + } + + char *line = styled_string_build(&line_str); + fprintf(stderr, "\033[1;94m%*d |\033[0m %s\n", pad - 1, last_line, line); + free(line); + for (size_t i = 0; i < sub_strs.len; i++) { + line = styled_string_build(&sub_strs.data[i]); + fprintf(stderr, "%*s\033[1;94m|\033[0m %s\n", pad, "", line); + free(line); + } + } + + styled_string_drop(line_str); + vec_drop(sub_strs); + vec_drop(spans); + + if (help != NULL) { + fprintf(stderr, "\033[1;96mhelp\033[0m: %s\n", help); + } +} diff --git a/ser/source.h b/ser/source.h new file mode 100644 index 0000000..05a921e --- /dev/null +++ b/ser/source.h @@ -0,0 +1,91 @@ +#ifndef SOURCE_H +#define SOURCE_H +#include "utils.h" +#include "vector_impl.h" + +#include +#include + +typedef struct { + uint32_t line; + uint32_t column; + uint32_t offset; +} Location; + +typedef struct { + Location loc; + uint32_t len; +} Span; + +typedef struct { + StringSlice slice; + Span span; +} SpannedStringSlice; + +int span_compare(const void *sa, const void *sb); + +bool sss_equal(const void *a, const void *b); +uint32_t sss_hash(Hasher state, const void *v); + +VECTOR_IMPL(Span, SpanVec, span); +VECTOR_IMPL(SpannedStringSlice, SpannedStringSliceVec, spanned_string_slice); + +typedef struct { + // The string content + const char *str; + // Path of the source file if available + const char *path; + uint32_t len; + IF_DEBUG(uint32_t ref_count;) +} Source; + +typedef enum : uint32_t { + SourceErrorNoError = 0, + SourceErrorReadFailed = 1, + SourceErrorOpenFailed = 2, +} SourceError; + +typedef enum { + ReportSeverityError, + ReportSeverityWarning, + ReportSeverityNote, +} ReportSeverity; + +typedef struct { + Span span; + ReportSeverity sev; + const char *message; +} ReportSpan; + +VECTOR_IMPL(ReportSpan, ReportSpanVec, report_span); + +static inline __attribute__((always_inline)) Location location(uint32_t line, uint32_t column, uint32_t offset) { + return (Location){.line = line, .column = column, .offset = offset}; +} + +// Initialize source from a string and its length (without null terminator), the string will be copied. +Source source_init(const char *str, uint32_t len); +// Try to initialize source from a FILE* +SourceError source_from_file(FILE *f, Source *src); +// Try to initialize source +SourceError source_open(const char *path, Source *src); +// Destroy source +void source_drop(Source src); +void source_report( + const Source *src, + Location loc, + ReportSeverity sev, + const ReportSpan *pspans, + uint32_t span_count, + const char *help, + const char *fmt, + ... +); + +static inline Span span_from_to(Location from, Location to) { + return (Span){ + .loc = from, + .len = to.offset - from.offset, + }; +} +#endif diff --git a/ser/utils.c b/ser/utils.c new file mode 100644 index 0000000..640ec2f --- /dev/null +++ b/ser/utils.c @@ -0,0 +1,147 @@ +#include "utils.h" + +#include "vector.h" + +#include +#include + +bool string_slice_equal(const void *_a, const void *_b) { + const StringSlice *a = (StringSlice *)_a; + const StringSlice *b = (StringSlice *)_b; + if (a->len != b->len) { + return false; + } + uint32_t len = a->len < b->len ? a->len : b->len; + return strncmp(a->ptr, b->ptr, len) == 0; +} + +uint32_t string_slice_hash(Hasher state, const void *_item) { + const StringSlice *item = (StringSlice *)_item; + return hash(state, (byte *)item->ptr, item->len); +} + +bool pointer_equal(const void *_a, const void *_b) { + const void *a = *(void **)_a; + const void *b = *(void **)_b; + return a == b; +} +uint32_t pointer_hash(Hasher state, const void *item) { return hash(state, (byte *)item, sizeof(void *)); } + +StyledString styled_string_init() { + return (StyledString){ + .chars = vec_init(), + .styles = vec_init(), + }; +} + +char *msprintf(const char *fmt, ...) { + va_list args; + va_start(args, fmt); + int len = vsnprintf(NULL, 0, fmt, args); + va_end(args); + char *res = malloc(len + 1); + assert_alloc(res); + va_start(args, fmt); + vsnprintf(res, len + 1, fmt, args); + va_end(args); + return res; +} + +void styled_string_set(StyledString *str, size_t index, const char *style, const char *s, size_t len) { + if (index > str->chars.len) { + vec_fill_range(&str->chars, str->chars.len, index, ' '); + vec_fill_range(&str->styles, str->styles.len, index, NULL); + } + vec_set_array(&str->chars, index, s, len); + vec_fill_range(&str->styles, index, index + len, NULL); + str->styles.data[index] = style; + // Reset the style if there are characters after + if (style != NULL && str->chars.len > index + len) { + str->styles.data[index + len] = "\033[0m"; + } +} + +void styled_string_set_style(StyledString *str, size_t index, const char *style, size_t len) { + if (index > str->chars.len) { + vec_fill_range(&str->chars, str->chars.len, index, ' '); + vec_fill_range(&str->styles, str->styles.len, index, NULL); + } + if (index + len > str->chars.len) { + vec_fill_range(&str->chars, str->chars.len, index + len, ' '); + } + vec_fill_range(&str->styles, index, index + len, NULL); + str->styles.data[index] = style; + // Reset the style if there are characters after + if (style != NULL && str->chars.len > index + len && str->styles.data[index + len] == NULL) { + str->styles.data[index + len] = "\033[0m"; + } +} + +void styled_string_clear(StyledString *str) { + vec_clear(&str->chars); + vec_clear(&str->styles); +} + +void styled_string_fill(StyledString *str, size_t index, char fill, size_t len) { + vec_fill_range(&str->chars, index, index + len, fill); +} + +void styled_string_push(StyledString *str, const char *style, const char *s) { + size_t len = strlen(s); + vec_push_array(&str->chars, s, len); + size_t index = str->styles.len; + vec_fill_range(&str->styles, index, str->chars.len, NULL); + str->styles.data[index] = style; +} + +char *styled_string_build(StyledString *str) { + CharVec res = vec_init(); + vec_grow(&res, str->chars.len + 1); + for (size_t i = 0; i < str->chars.len; i++) { + const char *style = str->styles.data[i]; + if (style != NULL) { + int len = strlen(style); + vec_push_array(&res, style, len); + } + vec_push(&res, str->chars.data[i]); + } + vec_push_array(&res, "\033[0m\0", 5); + return res.data; +} + +size_t styled_string_available_space(StyledString *str, size_t from, size_t stop_at) { + // We always have more space past the end of the string + if (from >= str->chars.len) + return stop_at; + size_t space = 0; + char *c = &str->chars.data[from]; + char *end = &str->chars.data[str->chars.len]; + while (space < stop_at && c < end && *c == ' ') { + space++; + c++; // Blasphemy + } + if (c == end || space == stop_at) { + // We either found enough space, or we got the end of the string (infinite space) + return stop_at; + } + return space; +} + +void styled_string_drop(StyledString str) { + vec_drop(str.styles); + vec_drop(str.chars); +} + +void charvec_push_str(CharVec *v, const char *str) { vec_push_array(v, str, strlen(str)); } + +void charvec_format(CharVec *v, const char *fmt, ...) { + va_list args; + va_start(args, fmt); + int len = vsnprintf(NULL, 0, fmt, args); + va_end(args); + vec_grow(v, v->len + len + 1); + va_start(args, fmt); + vsnprintf(&v->data[v->len], len + 1, fmt, args); + va_end(args); + v->len += len; +} diff --git a/ser/utils.h b/ser/utils.h new file mode 100644 index 0000000..baaca65 --- /dev/null +++ b/ser/utils.h @@ -0,0 +1,60 @@ +#ifndef UTILS_H +#define UTILS_H + +#ifdef NDEBUG +#define IF_DEBUG(...) +#else +#define IF_DEBUG(...) __VA_ARGS__ +#endif + +#include +#include + +typedef unsigned char byte; + +#include "hashmap.h" +#include "vector_impl.h" + +#define STRING_SLICE(str) ((StringSlice){.ptr = str, .len = sizeof(str) - 1}) + +typedef struct { + const char *ptr; + uint32_t len; +} StringSlice; + +bool string_slice_equal(const void *a, const void *b); +uint32_t string_slice_hash(Hasher state, const void *item); + +bool pointer_equal(const void *a, const void *b); +uint32_t pointer_hash(Hasher state, const void *item); + +VECTOR_IMPL(void *, PointerVec, pointer); +VECTOR_IMPL(const char *, ConstStringVec, const_string); +VECTOR_IMPL(char, CharVec, char); +VECTOR_IMPL(uint64_t, UInt64Vec, uint64); +VECTOR_IMPL(CharVec, CharVec2, char_vec, _vec_char_drop); +VECTOR_IMPL(StringSlice, StringSliceVec, string_slice); + +// Styled strings are very mutable strings +typedef struct { + CharVec chars; + ConstStringVec styles; +} StyledString; + +StyledString styled_string_init(); +void styled_string_clear(StyledString *str); +void styled_string_set(StyledString *str, size_t index, const char *style, const char *s, size_t len); +void styled_string_set_style(StyledString *str, size_t index, const char *style, size_t len); +void styled_string_fill(StyledString *str, size_t index, char fill, size_t len); +size_t styled_string_available_space(StyledString *str, size_t from, size_t stop_at); +void styled_string_push(StyledString *str, const char *style, const char *s); +char *styled_string_build(StyledString *str); +void styled_string_drop(StyledString str); +// Printf to an allocated string +char *msprintf(const char *fmt, ...); +void charvec_push_str(CharVec *v, const char *str); +void charvec_format(CharVec *v, const char *fmt, ...); + +VECTOR_IMPL(StyledString, StyledStringVec, styled_string, styled_string_drop); + +#endif diff --git a/ser/vector.h b/ser/vector.h new file mode 100644 index 0000000..cda21e2 --- /dev/null +++ b/ser/vector.h @@ -0,0 +1,30 @@ +#ifndef VECTOR_H +#define VECTOR_H +#include "arena_allocator.h" +#include "ast.h" +#include "codegen.h" +#include "eval.h" +#include "lexer.h" +#include "parser.h" +#include "source.h" +#include "utils.h" + +// This files contains the generic vector macro, which are generated according the VECTOR_IMPL_LIST + +// clang-format: off +#define VECTOR_IMPL_LIST \ + (Token, TokenVec, token, token_drop), (LexingError, LexingErrorVec, lexing_error, lexing_error_drop), \ + (AstItem, AstItemVec, ast_item), (AstField, AstFieldVec, ast_field), \ + (AstAttributeOrMessage, AstAttributeOrMessageVec, ast_attribute_or_message), \ + (ArenaBlock, ArenaBlockVec, arena_block, arena_block_drop), (ParsingError, ParsingErrorVec, parsing_error), \ + (Field, FieldVec, field, field_drop), (EvalError, EvalErrorVec, eval_error), \ + (const char *, ConstStringVec, const_string), (StringSlice, StringSliceVec, string_slice), (char, CharVec, char), \ + (CharVec, CharVec2, char_vec, _vec_char_drop), (ReportSpan, ReportSpanVec, report_span), \ + (StyledString, StyledStringVec, styled_string, styled_string_drop), (Span, SpanVec, span), \ + (void *, PointerVec, pointer), (SpannedStringSlice, SpannedStringSliceVec, spanned_string_slice), \ + (MessageObject, MessageObjectVec, message_object, message_drop), \ + (MessagesObject, MessagesObjectVec, messages_object, messages_drop), (uint64_t, UInt64Vec, uint64), \ + (FieldAccessor, FieldAccessorVec, field_accessor, field_accessor_drop) +#include "vector_impl.h" +// clang-format: on +#endif diff --git a/ser/vector_impl.h b/ser/vector_impl.h new file mode 100644 index 0000000..1b849e9 --- /dev/null +++ b/ser/vector_impl.h @@ -0,0 +1,226 @@ +#ifndef VECTOR_IMPL_H +#define VECTOR_IMPL_H +#include "assert.h" +#include "macro_utils.h" + +#include +#include +#include +#include + +#define _VECTOR_MAP_ADD(m, a, fst, ...) m(a, fst) __VA_OPT__(DEFER1(_VECTOR__MAP_ADD)()(m, a, __VA_ARGS__)) +#define _VECTOR__MAP_ADD() _VECTOR_MAP_ADD + +#define VECTOR_IMPL(T, V, qualifier, ...) \ + typedef struct { \ + T *data; \ + size_t len; \ + size_t cap; \ + } V; \ + __attribute__((unused)) static inline V _vec_##qualifier##_init() { return (V){.data = NULL, .len = 0, .cap = 0}; } \ + __attribute__((unused)) static inline void _vec_##qualifier##_drop(V vec) { \ + __VA_OPT__({ \ + for (size_t i = 0; i < vec.len; i++) { \ + __VA_ARGS__(vec.data[i]); \ + } \ + }) \ + if (vec.data != NULL) { \ + free(vec.data); \ + } \ + } \ + __attribute__((unused)) static inline void _vec_##qualifier##_grow(V *vec, size_t cap) { \ + if (cap <= vec->cap) \ + return; \ + if (vec->data == NULL || vec->cap == 0) { \ + vec->data = malloc(cap * sizeof(T)); \ + assert_alloc(vec->data); \ + vec->cap = cap; \ + return; \ + } \ + if (cap < 2 * vec->cap) { \ + cap = 2 * vec->cap; \ + } \ + if (cap < 4) { \ + cap = 4; \ + } \ + T *newp = realloc(vec->data, cap * sizeof(T)); \ + assert_alloc(newp); \ + vec->data = newp; \ + vec->cap = cap; \ + } \ + __attribute__((unused)) static inline V _vec_##qualifier##_init_with_cap(size_t cap) { \ + V vec = {.data = NULL, .len = 0, .cap = 0}; \ + _vec_##qualifier##_grow(&vec, cap); \ + return vec; \ + } \ + __attribute__((unused)) static inline void _vec_##qualifier##_shrink(V *vec) { \ + if (vec->len > 0) { \ + T *newp = realloc(vec->data, vec->len); \ + assert_alloc(newp); \ + vec->data = newp; \ + vec->cap = vec->len; \ + } else { \ + free(vec->data); \ + vec->data = NULL; \ + vec->cap = 0; \ + } \ + } \ + __attribute__((unused)) static inline void _vec_##qualifier##_push(V *vec, T val) { \ + _vec_##qualifier##_grow(vec, vec->len + 1); \ + vec->data[vec->len++] = val; \ + } \ + __attribute__((unused)) static inline void _vec_##qualifier##_push_array(V *vec, T const *vals, size_t count) { \ + _vec_##qualifier##_grow(vec, vec->len + count); \ + for (size_t i = 0; i < count; i++) { \ + vec->data[vec->len++] = vals[i]; \ + } \ + } \ + __attribute__((unused)) static inline V _vec_##qualifier##_clone(V *vec) { \ + if (vec->len == 0) { \ + return (V){.data = NULL, .len = 0, .cap = 0}; \ + } \ + V res = {.data = NULL, .len = 0, .cap = 0}; \ + _vec_##qualifier##_grow(&res, vec->len); \ + _vec_##qualifier##_push_array(&res, vec->data, vec->len); \ + return res; \ + } \ + __attribute__((unused)) static inline bool _vec_##qualifier##_pop_opt(V *vec, T *val) { \ + if (vec->len == 0) \ + return false; \ + vec->len--; \ + if (val != NULL) { \ + *val = vec->data[vec->len]; \ + } \ + __VA_OPT__(else { __VA_ARGS__(vec->data[vec->len]); }) \ + return true; \ + } \ + __attribute__((unused)) static inline T _vec_##qualifier##_pop(V *vec) { \ + debug_assert(vec->len > 0, "Popping zero length %s", #V); \ + return vec->data[--vec->len]; \ + } \ + __attribute__((unused)) static inline void _vec_##qualifier##_clear(V *vec) { \ + __VA_OPT__({ \ + for (size_t i = 0; i < vec->len; i++) { \ + __VA_ARGS__(vec->data[i]); \ + } \ + }) \ + vec->len = 0; \ + } \ + __attribute__((unused)) static inline T _vec_##qualifier##_get(V *vec, size_t index) { \ + debug_assert(index < vec->len, "Out of bound index, on %s (index is %lu, but length is %lu)", #V, index, vec->len); \ + return vec->data[index]; \ + } \ + __attribute__((unused)) static inline bool _vec_##qualifier##_get_opt(V *vec, size_t index, T *val) { \ + if (index >= vec->len) { \ + return false; \ + } else if (val != NULL) { \ + *val = vec->data[index]; \ + } \ + return true; \ + } \ + __attribute__((unused)) static inline T _vec_##qualifier##_take(V *vec, size_t index) { \ + debug_assert(index < vec->len, "Out of bound index, on %s (index is %lu but length is %lu)", #V, index, vec->len); \ + T res = vec->data[index]; \ + if (index != vec->len - 1) \ + memmove(&vec->data[index], &vec->data[index + 1], (vec->len - index) * sizeof(T)); \ + vec->len--; \ + return res; \ + } \ + __attribute__((unused)) static inline void _vec_##qualifier##_fill_range(V *vec, size_t from, size_t to, T item) { \ + debug_assert(from <= vec->len, "Can't start fill past the end of a vector"); \ + _vec_##qualifier##_grow(vec, to); \ + for (size_t i = from; i < to; i++) { \ + vec->data[i] = item; \ + } \ + vec->len = vec->len > to ? vec->len : to; \ + } \ + __attribute__((unused)) static inline void _vec_##qualifier##_fill(V *vec, T item) { \ + _vec_##qualifier##_fill_range(vec, 0, vec->len, item); \ + } \ + __attribute__((unused)) static inline void _vec_##qualifier##_insert(V *vec, size_t index, T val) { \ + debug_assert(index <= vec->len, "Can't insert past the end of vector"); \ + _vec_##qualifier##_grow(vec, vec->len + 1); \ + if (index < vec->len) { \ + memmove(&vec->data[index + 1], &vec->data[index], (vec->len - index) * sizeof(T)); \ + } \ + vec->data[index] = val; \ + vec->len++; \ + } \ + __attribute__((unused) \ + ) static inline void _vec_##qualifier##_insert_array(V *vec, size_t index, T const *vals, size_t count) { \ + debug_assert(index <= vec->len, "Can't insert past the end of vector"); \ + _vec_##qualifier##_grow(vec, vec->len + count); \ + if (index < vec->len) { \ + memmove(&vec->data[index + count], &vec->data[index], (vec->len - index) * sizeof(T)); \ + } \ + for (size_t i = 0; i < count; i++) { \ + vec->data[index + i] = vals[i]; \ + } \ + vec->len += count; \ + } \ + __attribute__((unused)) static inline void _vec_##qualifier##_set_array(V *vec, size_t index, T const *vals, size_t count) { \ + debug_assert(index <= vec->len, "Can't start set past the end of vector"); \ + _vec_##qualifier##_grow(vec, index + count); \ + for (size_t i = 0; i < count; i++) { \ + vec->data[index + i] = vals[i]; \ + } \ + vec->len = vec->len > (index + count) ? vec->len : (index + count); \ + } \ + __attribute__((unused)) static inline void _vec_##qualifier##_splice(V *vec, size_t index, size_t count) { \ + debug_assert(index < vec->len, "Can't splice past end of vector"); \ + if (count == 0) \ + return; \ + if (index + count < vec->len) { \ + memmove(&vec->data[index], &vec->data[index + count], (vec->len - index - count) * sizeof(T)); \ + } \ + vec->len -= count; \ + } \ + _Static_assert(1, "Semicolon required") + +typedef struct { + void *data; + size_t len; + size_t cap; +} AnyVec; + +#endif + +#ifdef VECTOR_IMPL_LIST + +#define _VECTOR_GEN(a, x) DEFER1(_VECTOR__GEN)(a, _VECTOR__GEN_CLOSE x +#define _VECTOR__GEN_CLOSE(a, b, c, ...) a, b, c __VA_OPT__(,) __VA_ARGS__) +#define _VECTOR__GEN(x, _, V, qualifier, ...) \ + V: \ + _vec_##qualifier##x, + +#define _VECTOR_GENERIC(expr, x) _Generic(expr, EVAL(CALL(_VECTOR_MAP_ADD, _VECTOR_GEN, _##x, VECTOR_IMPL_LIST)) AnyVec: (void)0) + +#define vec_init() \ + { .data = NULL, .len = 0, .cap = 0 } +#define vec_drop(vec) _VECTOR_GENERIC(vec, drop)(vec) +#define vec_grow(vec, len) _VECTOR_GENERIC(*(vec), grow)(vec, len) +#define vec_shrink(vec) _VECTOR_GENERIC(*(vec), shrink)(vec) +#define vec_push(vec, val) _VECTOR_GENERIC(*(vec), push)(vec, val) +#define vec_push_array(vec, vals, count) _VECTOR_GENERIC(*(vec), push_array)(vec, vals, count) +#define vec_clone(vec) _VECTOR_GENERIC(*(vec), clone)(vec) +#define vec_pop_opt(vec, val) _VECTOR_GENERIC(*(vec), pop_opt)(vec, val) +#define vec_pop(vec) _VECTOR_GENERIC(*(vec), pop)(vec) +#define vec_clear(vec) _VECTOR_GENERIC(*(vec), clear)(vec) +#define vec_get(vec, index) _VECTOR_GENERIC(*(vec), get)(vec, index) +#define vec_get_opt(vec, index, val) _VECTOR_GENERIC(*(vec), get_opt)(vec, index, val) +#define vec_take(vec, index) _VECTOR_GENERIC(*(vec), take)(vec, index) +#define vec_fill_range(vec, from, to, item) _VECTOR_GENERIC(*(vec), fill_range)(vec, from, to, item) +#define vec_fill(vec, item) _VECTOR_GENERIC(*(vec), fill)(vec, item) +#define vec_insert(vec, index, val) _VECTOR_GENERIC(*(vec), insert)(vec, index, val) +#define vec_insert_array(vec, index, vals, count) _VECTOR_GENERIC(*(vec), insert_array)(vec, index, vals, count) +#define vec_set_array(vec, index, vals, count) _VECTOR_GENERIC(*(vec), set_array)(vec, index, vals, count) +#define vec_splice(vec, index, count) _VECTOR_GENERIC(*(vec), splice)(vec, index, count) +#define vec_foreach(vec, var, expr) \ + do { \ + for (size_t _i = 0; _i < (vec)->len; _i++) { \ + typeof(*(vec)->data) var = (vec)->data[_i]; \ + expr; \ + } \ + } while (false) + +#endif diff --git a/server.c b/server.c index 8b0e689..8307d06 100644 --- a/server.c +++ b/server.c @@ -118,6 +118,33 @@ static void print_config() { printf("\n"); } +void print_message_buffer(const uint8_t *buf, int len) { + bool last_beg = false; + for (int i = 0; i < len; i++) { + if (i + MSG_MAGIC_SIZE <= len) { + MsgMagic magic = *(MsgMagic *)(&buf[i]); + if (magic == MSG_MAGIC_START) { + printf(" \033[32m%08lX\033[0m", magic); + i += MSG_MAGIC_SIZE - 1; + last_beg = true; + continue; + } else if (magic == MSG_MAGIC_END) { + printf(" \033[32m%08lX\033[0m", magic); + i += MSG_MAGIC_SIZE - 1; + continue; + } + } + + if (last_beg) { + last_beg = false; + printf(" \033[034m%04X\033[0m", *(uint16_t*)&buf[i]); + i++; + } else { + printf(" %02X", buf[i]); + } + } +} + void device_thread_exit(int _sig) { struct DeviceThreadArgs *args = pthread_getspecific(device_args_key); printf("CONN(%d): [%d] exiting\n", args->conn->id, args->index); @@ -145,8 +172,8 @@ void *device_thread(void *args_) { TRAP_IGN(SIGPIPE); TRAP(SIGTERM, device_thread_exit); - uint8_t buf[2048] __attribute__((aligned(4))) = {0}; - MessageDeviceInfo dev_info; + uint8_t buf[2048] __attribute__((aligned(4))) = {0}; + DeviceInfo dev_info; while (true) { if (*args->controller != NULL) { @@ -169,21 +196,21 @@ void *device_thread(void *args_) { // Send over device info { - int len = msg_serialize(buf, 2048, (Message *)&dev_info); + int len = msg_device_serialize(buf, 2048, (DeviceMessage *)&dev_info); if (write(args->conn->socket, buf, len) == -1) { printf("CONN(%d): [%d] Couldn't send device info\n", args->conn->id, args->index); break; } } - MessageDeviceReport report = {0}; + DeviceReport report = {0}; - report.code = DeviceReport; - report.abs_count = ctr->dev.device_info.abs_count; - report.rel_count = ctr->dev.device_info.rel_count; - report.key_count = ctr->dev.device_info.key_count; - report.slot = args->index; - report.index = controller_index; + report.tag = DeviceTagReport; + report.abs.len = ctr->dev.device_info.abs.len; + report.rel.len = ctr->dev.device_info.rel.len; + report.key.len = ctr->dev.device_info.key.len; + report.slot = args->index; + report.index = controller_index; while (true) { struct input_event event; @@ -203,13 +230,12 @@ void *device_thread(void *args_) { } if (event.type == EV_SYN) { - int len = msg_serialize(buf, 2048, (Message *)&report); + int len = msg_device_serialize(buf, 2048, (DeviceMessage *)&report); if (len < 0) { printf("CONN(%d): [%d] Couldn't serialize report %d\n", args->conn->id, args->index, len); continue; }; - send(args->conn->socket, buf, len, 0); } else if (event.type == EV_ABS) { int index = ctr->dev.mapping.abs_indices[event.code]; @@ -219,7 +245,7 @@ void *device_thread(void *args_) { continue; }; - report.abs[index] = event.value; + report.abs.data[index] = event.value; } else if (event.type == EV_REL) { int index = ctr->dev.mapping.rel_indices[event.code]; @@ -228,7 +254,7 @@ void *device_thread(void *args_) { continue; }; - report.rel[index] = event.value; + report.rel.data[index] = event.value; } else if (event.type == EV_KEY) { int index = ctr->dev.mapping.key_indices[event.code]; @@ -236,17 +262,17 @@ void *device_thread(void *args_) { printf("CONN(%d): [%d] Invalid key\n", args->conn->id, args->index); continue; }; - report.key[index] = !!event.value; + report.key.data[index] = !!event.value; } } // Send device destroy message { - MessageDestroy dstr; - dstr.code = DeviceDestroy; + DeviceDestroy dstr; + dstr.tag = DeviceTagDestroy; dstr.index = args->index; - int len = msg_serialize(buf, 2048, (Message *)&dstr); + int len = msg_device_serialize(buf, 2048, (DeviceMessage *)&dstr); if (write(args->conn->socket, buf, len) == -1) { printf("CONN(%d): [%d] Couldn't send device destroy message\n", args->conn->id, args->index); break; @@ -310,7 +336,7 @@ void *server_handle_conn(void *args_) { if (len <= 0) { closing_message = "Lost peer (from recv)"; goto conn_end; - } else if (len > 1 + MAGIC_SIZE * 2) { + } else if (len > 1) { printf("CONN(%d): Got message: ", args->id); printf("\n"); } else { @@ -318,10 +344,10 @@ void *server_handle_conn(void *args_) { } // Parse message - Message msg; - int msg_len = msg_deserialize(buf, len, &msg); + DeviceMessage msg; + int msg_len = msg_device_deserialize(buf, len, &msg); if (msg_len < 0) { - if (len > 1 + MAGIC_SIZE * 2) { + if (len > 1) { printf("CONN(%d): Couldn't parse message: ", args->id); print_message_buffer(buf, len); printf("\n"); @@ -332,7 +358,7 @@ void *server_handle_conn(void *args_) { } // Handle message - if (msg.code == ControllerState) { + if (msg.tag == DeviceTagControllerState) { int i = msg.controller_state.index; if (i >= device_controllers.len) { printf("CONN(%d): Invalid controller index in controller state message\n", args->id); @@ -346,10 +372,10 @@ void *server_handle_conn(void *args_) { } apply_controller_state(ctr, &msg.controller_state); - } else if (msg.code == Request) { + } else if (msg.tag == DeviceTagRequest) { if (got_request) { printf("CONN(%d): Illegal Request message after initial request\n", args->id); - msg_free(&msg); + msg_device_free(&msg); continue; } @@ -357,7 +383,7 @@ void *server_handle_conn(void *args_) { printf("CONN(%d): Got client request\n", args->id); - for (int i = 0; i < msg.request.request_count; i++) { + for (int i = 0; i < msg.request.requests.len; i++) { int index = device_controllers.len; Controller *ctr = NULL; vec_push(&device_controllers, &ctr); @@ -365,13 +391,14 @@ void *server_handle_conn(void *args_) { struct DeviceThreadArgs *dev_args = malloc(sizeof(struct DeviceThreadArgs)); dev_args->controller = vec_get(&device_controllers, index); - dev_args->tag_count = msg.request.requests[i].count; + dev_args->tag_count = msg.request.requests.data[i].tags.len; dev_args->tags = malloc(dev_args->tag_count * sizeof(char *)); dev_args->conn = args; dev_args->index = index; for (int j = 0; j < dev_args->tag_count; j++) { - dev_args->tags[j] = strdup(msg.request.requests[i].tags[j]); + Tag t = msg.request.requests.data[i].tags.data[j]; + dev_args->tags[j] = strndup(t.name.data, t.name.len); } pthread_t thread; @@ -379,7 +406,7 @@ void *server_handle_conn(void *args_) { vec_push(&device_threads, &thread); } - msg_free(&msg); + msg_device_free(&msg); } else { printf("CONN(%d): Illegal message\n", args->id); }