sched: enforce ref-counting on current task/thread pointers

This commit is contained in:
2026-04-01 18:17:05 +01:00
parent 15c2207ab9
commit 512356ac2d
28 changed files with 364 additions and 103 deletions
+6 -4
View File
@@ -139,7 +139,6 @@ extern kern_status_t channel_recv_msg(
kern_msg_t *out_msg,
unsigned long *irq_flags)
{
struct thread *self = current_thread();
struct msg *msg = NULL;
unsigned long msg_lock_flags;
@@ -170,7 +169,7 @@ extern kern_status_t channel_recv_msg(
}
struct task *sender = msg->msg_sender_thread->tr_parent;
struct task *receiver = self->tr_parent;
struct task *receiver = get_current_task();
struct address_space *src = sender->t_address_space,
*dst = receiver->t_address_space;
@@ -192,6 +191,7 @@ extern kern_status_t channel_recv_msg(
if (status != KERN_OK) {
kmsg_reply_error(msg, status, &msg_lock_flags);
put_current_task(receiver);
return status;
}
@@ -216,6 +216,7 @@ extern kern_status_t channel_recv_msg(
&receiver->t_handles_lock,
f);
address_space_unlock_pair_irqrestore(src, dst, f);
put_current_task(receiver);
if (status != KERN_OK) {
kmsg_reply_error(msg, status, &msg_lock_flags);
@@ -250,11 +251,10 @@ extern kern_status_t channel_reply_msg(
return KERN_INVALID_ARGUMENT;
}
struct thread *self = current_thread();
/* the task that is about to receive the response */
struct task *receiver = msg->msg_sender_thread->tr_parent;
/* the task that is about to send the response */
struct task *sender = self->tr_parent;
struct task *sender = get_current_task();
struct address_space *src = sender->t_address_space,
*dst = receiver->t_address_space;
@@ -275,6 +275,7 @@ extern kern_status_t channel_reply_msg(
if (status != KERN_OK) {
kmsg_reply_error(msg, status, &msg_lock_flags);
put_current_task(sender);
return status;
}
@@ -299,6 +300,7 @@ extern kern_status_t channel_reply_msg(
&receiver->t_handles_lock,
f);
address_space_unlock_pair_irqrestore(src, dst, f);
put_current_task(sender);
if (status != KERN_OK) {
kmsg_reply_error(msg, status, &msg_lock_flags);
+23 -3
View File
@@ -32,8 +32,10 @@ static kern_status_t get_data(
{
spin_lock_t *lock = NULL;
struct btree *futex_list = NULL;
struct task *self = NULL;
if (flags & FUTEX_PRIVATE) {
struct task *self = current_task();
self = get_current_task();
lock = &self->t_base.ob_lock;
futex_list = &self->t_futex;
} else if (flags & FUTEX_SHARED) {
@@ -48,12 +50,20 @@ static kern_status_t get_data(
if (!futex && !(flags & FUTEX_CREATE)) {
spin_unlock_irqrestore(lock, *irq_flags);
if (self) {
put_current_task(self);
}
return KERN_NO_ENTRY;
}
futex = vm_cache_alloc(&futex_cache, VM_NORMAL);
if (!futex) {
spin_unlock_irqrestore(lock, *irq_flags);
if (self) {
put_current_task(self);
}
return KERN_NO_MEMORY;
}
@@ -61,6 +71,10 @@ static kern_status_t get_data(
put_futex(futex_list, futex);
if (self) {
put_current_task(self);
}
*out = futex;
*out_lock = lock;
return KERN_OK;
@@ -68,9 +82,10 @@ static kern_status_t get_data(
static kern_status_t cleanup_data(struct futex *futex, unsigned int flags)
{
struct task *self = NULL;
struct btree *futex_list = NULL;
if (flags & FUTEX_PRIVATE) {
struct task *self = current_task();
self = get_current_task();
futex_list = &self->t_futex;
} else if (flags & FUTEX_SHARED) {
futex_list = &shared_futex_list;
@@ -81,12 +96,16 @@ static kern_status_t cleanup_data(struct futex *futex, unsigned int flags)
btree_delete(futex_list, &futex->f_node);
vm_cache_free(&futex_cache, futex);
if (self) {
put_current_task(self);
}
return KERN_OK;
}
static kern_status_t futex_get_shared(kern_futex_t *futex, futex_key_t *out)
{
struct task *self = current_task();
struct task *self = get_current_task();
struct address_space *space = self->t_address_space;
unsigned long flags;
@@ -97,6 +116,7 @@ static kern_status_t futex_get_shared(kern_futex_t *futex, futex_key_t *out)
out,
&flags);
address_space_unlock_irqrestore(space, flags);
put_current_task(self);
return status;
}
+2 -1
View File
@@ -175,7 +175,7 @@ void object_wait_signal(
uint32_t signals,
unsigned long *irq_flags)
{
struct thread *self = current_thread();
struct thread *self = get_current_thread();
struct wait_item waiter;
wait_item_init(&waiter, self);
for (;;) {
@@ -189,4 +189,5 @@ void object_wait_signal(
object_lock_irqsave(obj, irq_flags);
}
thread_wait_end(&waiter, &obj->ob_wq);
put_current_thread(self);
}
+5 -2
View File
@@ -19,8 +19,8 @@ void panic_irq(struct ml_cpu_context *ctx, const char *fmt, ...)
printk("---[ kernel panic: %s", buf);
printk("kernel: " BUILD_ID ", compiler version: " __VERSION__);
struct task *task = current_task();
struct thread *thr = current_thread();
struct task *task = get_current_task();
struct thread *thr = get_current_thread();
if (task && thr) {
printk("task: %s (id: %d, thread: %d)",
@@ -31,6 +31,9 @@ void panic_irq(struct ml_cpu_context *ctx, const char *fmt, ...)
printk("task: [bootstrap]");
}
put_current_thread(thr);
put_current_task(task);
printk("cpu: %u", this_cpu());
ml_print_cpu_state(ctx);
+12 -4
View File
@@ -37,7 +37,7 @@ struct port *port_cast(struct object *obj)
static void wait_for_reply(struct msg *msg, unsigned long *lock_flags)
{
struct wait_item waiter;
struct thread *self = current_thread();
struct thread *self = get_current_thread();
wait_item_init(&waiter, self);
for (;;) {
@@ -52,6 +52,7 @@ static void wait_for_reply(struct msg *msg, unsigned long *lock_flags)
}
self->tr_state = THREAD_READY;
put_current_thread(self);
}
struct port *port_create(void)
@@ -83,9 +84,12 @@ kern_status_t port_connect(struct port *port, struct channel *remote)
msg->msg_status = KMSG_ASYNC;
msg->msg_type = KERN_MSG_TYPE_EVENT;
msg->msg_event = KERN_MSG_EVENT_CONNECTION;
msg->msg_sender_thread_id = current_thread()->tr_id;
msg->msg_sender_port_id = port->p_base.ob_id;
struct thread *self = get_current_thread();
msg->msg_sender_thread_id = self->tr_id;
put_current_thread(self);
unsigned long flags;
channel_lock_irqsave(remote, &flags);
channel_enqueue_msg(remote, msg);
@@ -112,9 +116,12 @@ kern_status_t port_disconnect(struct port *port)
msg->msg_status = KMSG_ASYNC;
msg->msg_type = KERN_MSG_TYPE_EVENT;
msg->msg_event = KERN_MSG_EVENT_DISCONNECTION;
msg->msg_sender_thread_id = current_thread()->tr_id;
msg->msg_sender_port_id = port->p_base.ob_id;
struct thread *self = get_current_thread();
msg->msg_sender_thread_id = self->tr_id;
put_current_thread(self);
unsigned long flags;
channel_lock_irqsave(port->p_remote, &flags);
channel_enqueue_msg(port->p_remote, msg);
@@ -136,7 +143,7 @@ kern_status_t port_send_msg(
return KERN_BAD_STATE;
}
struct thread *self = current_thread();
struct thread *self = get_current_thread();
struct msg msg;
memset(&msg, 0x0, sizeof msg);
msg.msg_type = KERN_MSG_TYPE_DATA;
@@ -156,6 +163,7 @@ kern_status_t port_send_msg(
channel_lock_irqsave(port->p_remote, &flags);
btree_delete(&port->p_remote->c_msg, &msg.msg_node);
channel_unlock_irqrestore(port->p_remote, flags);
put_current_thread(self);
return msg.msg_result;
}