Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions crates/tako/src/internal/server/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,6 @@ impl Core {
(&mut self.tasks, &mut self.data_objects)
}

#[cfg(test)]
pub fn get_resource_map_mut(&mut self) -> &mut GlobalResourceMapping {
&mut self.resource_map
}

pub fn new_worker_id(&mut self) -> WorkerId {
self.worker_id_counter += 1;
WorkerId::new(self.worker_id_counter)
Expand Down Expand Up @@ -619,7 +614,9 @@ mod tests {
use crate::internal::server::task::TaskRuntimeState;
use crate::internal::server::worker::Worker;
use crate::internal::server::workergroup::WorkerGroup;
use crate::internal::tests::utils::task;

use crate::tests::utils::env::TestEnv;

use crate::{TaskId, WorkerId};

impl Core {
Expand Down Expand Up @@ -700,15 +697,13 @@ mod tests {

#[test]
fn add_remove() {
let mut core = Core::default();
let rmap = core.get_resource_map_mut();
let t = task::task(101, rmap);
core.add_task(t);
let mut rt = TestEnv::new();
rt.new_task_default(101);
let mut objs_to_remove = ObjsToRemoveFromWorkers::new();
assert!(matches!(
core.remove_task(101.into(), &mut objs_to_remove),
rt.core().remove_task(101.into(), &mut objs_to_remove),
TaskRuntimeState::Waiting(_)
));
assert_eq!(core.find_task(101.into()), None);
assert_eq!(rt.core().find_task(101.into()), None);
}
}
134 changes: 66 additions & 68 deletions crates/tako/src/internal/server/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,96 +144,91 @@ pub fn task_explain_for_worker(

#[cfg(test)]
mod tests {
use crate::internal::common::resources::map::GlobalResourceMapping;

use crate::internal::server::explain::{TaskExplainItem, task_explain_for_worker};
use crate::internal::server::worker::Worker;

use crate::internal::server::workergroup::WorkerGroup;
use crate::internal::tests::utils::schedule::create_test_worker_config;
use crate::internal::tests::utils::task::TaskBuilder;
use crate::resources::{
ResourceAmount, ResourceDescriptor, ResourceDescriptorItem, ResourceIdMap,
};
use crate::{Set, WorkerId};

use crate::resources::ResourceAmount;
use crate::tests::utils::env::TestEnv;
use crate::tests::utils::worker::WorkerBuilder;
use crate::{Set, TaskId, WorkerId};
use std::time::{Duration, Instant};

#[test]
fn explain_single_node() {
let mut rqs = GlobalResourceMapping::default();
let resource_map = ResourceIdMap::from_vec(vec!["cpus".to_string(), "gpus".to_string()]);
let now = Instant::now();

let wcfg = create_test_worker_config(1.into(), ResourceDescriptor::simple_cpus(4));
let worker1 = Worker::new(1.into(), wcfg, &resource_map, now);

let mut wcfg = create_test_worker_config(
2.into(),
ResourceDescriptor::new(
vec![
ResourceDescriptorItem::range("cpus", 1, 10),
ResourceDescriptorItem::range("gpus", 1, 4),
],
Default::default(),
),
let mut rt = TestEnv::new();
rt.new_named_resource("gpus");
rt.new_worker_with_id(1, &WorkerBuilder::new(4));
rt.new_worker_with_id(
2,
&WorkerBuilder::empty()
.res_range("cpus", 1, 10)
.res_range("gpus", 1, 4)
.time_limit(Duration::from_secs(40_000)),
);
wcfg.time_limit = Some(Duration::from_secs(40_000));
let worker2 = Worker::new(2.into(), wcfg, &resource_map, now);

let explain = |task, rqs: &GlobalResourceMapping, worker, now| {
let resource_map = rt.core().create_resource_map();
let explain = |rt: &mut TestEnv, task, worker, now| {
let group = WorkerGroup::new(Set::new());
let (task_map, worker_map, rqs) = rt.core().split_tasks_workers_requests_mut();
task_explain_for_worker(
&resource_map,
rqs.get_resource_rq_map(),
task,
worker,
rqs,
task_map.get_task(TaskId::new_test(task)),
worker_map.get_worker(WorkerId::new(worker)),
&group,
now,
)
};

let task_id = 1;
let task = TaskBuilder::new(task_id).build(&mut rqs);
let r = explain(&task, &rqs, &worker1, now);
let _rqs = rt.core().resource_map_mut();
let now = Instant::now();
rt.new_task(1, &TaskBuilder::new());
let (_task_map, _worker_map, _rqs) = rt.core().split_tasks_workers_requests_mut();
let r = explain(&mut rt, 1, 1, now);
assert_eq!(r.variants.len(), 1);
assert_eq!(r.variants[0].len(), 1);
assert_eq!(r.n_enabled_variants(), 1);

let task = TaskBuilder::new(task_id)
.time_request(20_000)
.build(&mut rqs);
let r = explain(&task, &rqs, &worker1, now);
rt.new_task(2, &TaskBuilder::new().time_request(20_000));
let r = explain(&mut rt, 2, 1, now);
assert_eq!(r.variants.len(), 1);
assert_eq!(r.variants[0].len(), 2);
assert_eq!(r.n_enabled_variants(), 1);

let r = explain(&task, &rqs, &worker2, now);
let r = explain(&mut rt, 2, 2, now);
assert_eq!(r.variants.len(), 1);
assert_eq!(r.variants[0].len(), 2);
assert_eq!(r.n_enabled_variants(), 1);

let now2 = now + Duration::from_secs(21_000);
let r = explain(&task, &rqs, &worker1, now2);
let r = explain(&mut rt, 2, 1, now2);
assert_eq!(r.variants.len(), 1);
assert_eq!(r.variants[0].len(), 2);
assert_eq!(r.n_enabled_variants(), 1);

let r = explain(&task, &rqs, &worker2, now2);
let r = explain(&mut rt, 2, 2, now2);
assert_eq!(r.variants.len(), 1);
assert_eq!(r.variants[0].len(), 2);
assert!(matches!(
r.variants[0][0],
TaskExplainItem::Time {
min_time,
remaining_time,
} if min_time == Duration::from_secs(20_000) && remaining_time == Some(Duration::from_secs(19_000))
} if min_time == Duration::from_secs(20_000) && (remaining_time.unwrap().as_secs().abs_diff(19_000) < 3)
));
assert_eq!(r.n_enabled_variants(), 0);

let task = TaskBuilder::new(task_id)
.time_request(20_000)
.cpus_compact(30)
.add_resource(1, 3)
.build(&mut rqs);
let r = explain(&task, &rqs, &worker2, now);
rt.new_task(
3,
&TaskBuilder::new()
.time_request(20_000)
.cpus(30)
.add_resource(1, 3),
);
let r = explain(&mut rt, 3, 2, now);
assert_eq!(r.variants.len(), 1);
assert_eq!(r.variants[0].len(), 3);
assert!(matches!(
Expand All @@ -244,25 +239,27 @@ mod tests {
));
assert_eq!(r.n_enabled_variants(), 0);

let task = TaskBuilder::new(task_id)
.time_request(30_000)
.cpus_compact(15)
.add_resource(1, 8)
.next_resources()
.cpus_compact(2)
.add_resource(1, 32)
.build(&mut rqs);
let r = explain(&task, &rqs, &worker2, now2);
rt.new_task(
4,
&TaskBuilder::new()
.time_request(30_000)
.cpus(15)
.add_resource(1, 8)
.next_resources()
.cpus(2)
.add_resource(1, 32),
);
let r = explain(&mut rt, 4, 2, now2);
assert_eq!(r.variants.len(), 2);
assert_eq!(r.variants[0].len(), 3);
assert_eq!(r.variants[1].len(), 2);

assert!(matches!(
r.variants[0][0],
TaskExplainItem::Time {
min_time,
remaining_time,
} if min_time == Duration::from_secs(30_000) && remaining_time == Some(Duration::from_secs(19_000))
));
} if min_time == Duration::from_secs(30_000) && (remaining_time.unwrap().as_secs().abs_diff(19_000) < 3)));
assert!(matches!(
&r.variants[0][1],
TaskExplainItem::Resources {
Expand All @@ -285,24 +282,25 @@ mod tests {

#[test]
fn explain_multi_node() {
let mut rqs = GlobalResourceMapping::default();
let resource_map = ResourceIdMap::from_vec(vec!["cpus".to_string(), "gpus".to_string()]);
let now = Instant::now();
let mut rt = TestEnv::new();

rt.new_worker_with_id(1, &WorkerBuilder::new(4));

let wcfg = create_test_worker_config(1.into(), ResourceDescriptor::simple_cpus(4));
let worker = Worker::new(1.into(), wcfg, &resource_map, now);
let task = TaskBuilder::new(1).n_nodes(4).build(&mut rqs);
let _task = rt.new_task(1, &TaskBuilder::new().n_nodes(4));
let mut wset = Set::new();
wset.insert(WorkerId::new(1));
wset.insert(WorkerId::new(2));
wset.insert(WorkerId::new(3));
wset.insert(WorkerId::new(132));
let group = WorkerGroup::new(wset);
let resource_map = rt.core().create_resource_map();
let (task_map, worker_map, rqs) = rt.core().split_tasks_workers_requests_mut();
let r = task_explain_for_worker(
&resource_map,
rqs.get_resource_rq_map(),
&task,
&worker,
rqs,
task_map.get_task(TaskId::new_test(1)),
worker_map.get_worker(WorkerId::new(1)),
&group,
now,
);
Expand All @@ -316,9 +314,9 @@ mod tests {
let group = WorkerGroup::new(wset);
let r = task_explain_for_worker(
&resource_map,
rqs.get_resource_rq_map(),
&task,
&worker,
rqs,
task_map.get_task(TaskId::new_test(1)),
worker_map.get_worker(WorkerId::new(1)),
&group,
now,
);
Expand Down
39 changes: 18 additions & 21 deletions crates/tako/src/internal/server/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,11 @@ fn estimate_shared_data_size(data: &ComputeTaskSharedData) -> usize {

#[cfg(test)]
mod tests {
use crate::internal::common::resources::map::GlobalResourceMapping;
use crate::internal::server::core::Core;

use crate::internal::server::task::{Task, TaskRuntimeState};
use crate::internal::tests::utils::schedule::submit_test_tasks;
use crate::internal::tests::utils::task;
use crate::internal::tests::utils::task::task_with_deps;

use crate::tests::utils::env::TestEnv;
use crate::tests::utils::task::TaskBuilder;
use std::default::Default;

impl Task {
Expand All @@ -521,29 +520,27 @@ mod tests {

#[test]
fn task_consumers_empty() {
let mut rmap = GlobalResourceMapping::default();
let a = task::task(0, &mut rmap);
let mut rt = TestEnv::new();
let a = rt.new_task_default(0);
let mut s = crate::Set::new();
a.collect_recursive_consumers(&Default::default(), &mut s);
rt.task(a)
.collect_recursive_consumers(&Default::default(), &mut s);
assert!(s.is_empty());
}

#[test]
fn task_recursive_consumers() {
let mut core = Core::default();
let rmap = core.get_resource_map_mut();
let a = task::task(0, rmap);
let b = task_with_deps(1, &[&a], rmap);
let c = task_with_deps(2, &[&b], rmap);
let d = task_with_deps(3, &[&b], rmap);
let e = task_with_deps(4, &[&c, &d], rmap);

let expected_ids = vec![b.id, c.id, d.id, e.id];
submit_test_tasks(&mut core, vec![a, b, c, d, e]);

let mut rt = TestEnv::new();
let a = rt.new_task_default(0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😆 I feel like you were reading my slides.

let b = rt.new_task(1, &TaskBuilder::new().task_deps(&[a]));
let c = rt.new_task(2, &TaskBuilder::new().task_deps(&[b]));
let d = rt.new_task(3, &TaskBuilder::new().task_deps(&[b]));
let e = rt.new_task(4, &TaskBuilder::new().task_deps(&[c, d]));

let expected_ids = vec![b, c, d, e];
let mut s = crate::Set::new();
core.get_task(0.into())
.collect_recursive_consumers(core.task_map(), &mut s);
let tasks = rt.task_map();
rt.task(0).collect_recursive_consumers(tasks, &mut s);
assert_eq!(s, expected_ids.into_iter().collect());
}
}
Loading
Loading