diff --git a/BUILD.bazel b/BUILD.bazel
index 7f0884f348aa2..88088be4788d9 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -2735,6 +2735,7 @@ pyx_library(
"//:global_state_accessor_lib",
"//:ray_util",
"//:raylet_lib",
+ "//:redis_client",
"//:src/ray/ray_exported_symbols.lds",
"//:src/ray/ray_version_script.lds",
"//:stats_lib",
diff --git a/ci/env/install-core-prerelease-dependencies.sh b/ci/env/install-core-prerelease-dependencies.sh
index 5efa3ddfe3bff..3182878f9f99f 100755
--- a/ci/env/install-core-prerelease-dependencies.sh
+++ b/ci/env/install-core-prerelease-dependencies.sh
@@ -3,7 +3,8 @@
set -e
# install all unbounded dependencies in setup.py for ray core
-for dependency in attrs jsonschema aiosignal frozenlist requests grpcio protobuf
+# TOOD(scv119) reenable grpcio once https://github.com/grpc/grpc/issues/31885 is fixed.
+for dependency in attrs jsonschema aiosignal frozenlist requests protobuf
do
python -m pip install -U --pre --upgrade-strategy=eager $dependency
done
diff --git a/ci/pipeline/determine_tests_to_run.py b/ci/pipeline/determine_tests_to_run.py
index 635a28d6694a0..7d10e80ad09f8 100644
--- a/ci/pipeline/determine_tests_to_run.py
+++ b/ci/pipeline/determine_tests_to_run.py
@@ -183,6 +183,7 @@ def get_commit_range():
RAY_CI_SERVE_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
+ RAY_CI_JAVA_AFFECTED = 1
elif changed_file.startswith("python/ray/dashboard"):
RAY_CI_DASHBOARD_AFFECTED = 1
# https://github.com/ray-project/ray/pull/15981
diff --git a/dashboard/client/src/App.tsx b/dashboard/client/src/App.tsx
index c48be1456e9f8..302a9afe3985c 100644
--- a/dashboard/client/src/App.tsx
+++ b/dashboard/client/src/App.tsx
@@ -8,6 +8,7 @@ import Events from "./pages/event/Events";
import Loading from "./pages/exception/Loading";
import JobList, { NewIAJobsPage } from "./pages/job";
import { JobDetailChartsPage } from "./pages/job/JobDetail";
+import { JobDetailActorsPage } from "./pages/job/JobDetailActorPage";
import { JobDetailInfoPage } from "./pages/job/JobDetailInfoPage";
import { JobDetailLayout } from "./pages/job/JobDetailLayout";
import { DEFAULT_VALUE, MainNavContext } from "./pages/layout/mainNavContext";
@@ -205,11 +206,19 @@ const App = () => {
-
+
}
path=""
/>
+
+
+
+ }
+ path="actors"
+ />
} path="logs">
diff --git a/dashboard/client/src/common/CollapsibleSection.tsx b/dashboard/client/src/common/CollapsibleSection.tsx
index 879ce6bf692c0..2f42d82ff7d6e 100644
--- a/dashboard/client/src/common/CollapsibleSection.tsx
+++ b/dashboard/client/src/common/CollapsibleSection.tsx
@@ -1,6 +1,8 @@
import { createStyles, makeStyles, Typography } from "@material-ui/core";
-import React, { PropsWithChildren, useState } from "react";
-import { RiArrowDownSLine, RiArrowUpSLine } from "react-icons/ri";
+import classNames from "classnames";
+import React, { PropsWithChildren, useEffect, useState } from "react";
+import { RiArrowDownSLine, RiArrowRightSLine } from "react-icons/ri";
+import { ClassNameProps } from "./props";
const useStyles = makeStyles((theme) =>
createStyles({
@@ -10,6 +12,7 @@ const useStyles = makeStyles((theme) =>
flexWrap: "nowrap",
alignItems: "center",
fontWeight: 500,
+ cursor: "pointer",
},
icon: {
marginRight: theme.spacing(1),
@@ -17,25 +20,42 @@ const useStyles = makeStyles((theme) =>
height: 24,
},
body: {
- marginTop: theme.spacing(3),
+ marginTop: theme.spacing(1),
+ },
+ bodyHidden: {
+ display: "none",
},
}),
);
-type CollapsibleSectionProps = PropsWithChildren<{
- title: string;
- startExpanded?: boolean;
- className?: string;
-}>;
+type CollapsibleSectionProps = PropsWithChildren<
+ {
+ title: string;
+ startExpanded?: boolean;
+ /**
+ * An optimization to not avoid re-rendering the contents of the collapsible section.
+ * When enabled, we will keep the content around when collapsing but hide it via css.
+ */
+ keepRendered?: boolean;
+ } & ClassNameProps
+>;
export const CollapsibleSection = ({
title,
startExpanded = false,
className,
children,
+ keepRendered,
}: CollapsibleSectionProps) => {
const classes = useStyles();
const [expanded, setExpanded] = useState(startExpanded);
+ const [rendered, setRendered] = useState(expanded);
+
+ useEffect(() => {
+ if (expanded) {
+ setRendered(true);
+ }
+ }, [expanded]);
const handleExpandClick = () => {
setExpanded(!expanded);
@@ -51,11 +71,19 @@ export const CollapsibleSection = ({
{expanded ? (
) : (
-
+
)}
{title}
- {expanded &&
{children}
}
+ {(expanded || (keepRendered && rendered)) && (
+
+ {children}
+
+ )}
);
};
diff --git a/dashboard/client/src/components/TaskTable.tsx b/dashboard/client/src/components/TaskTable.tsx
index e14055fa68a0c..c833cb1a2ae0b 100644
--- a/dashboard/client/src/components/TaskTable.tsx
+++ b/dashboard/client/src/components/TaskTable.tsx
@@ -13,9 +13,11 @@ import {
import Autocomplete from "@material-ui/lab/Autocomplete";
import Pagination from "@material-ui/lab/Pagination";
import React, { useState } from "react";
+import { DurationText } from "../common/DurationText";
import rowStyles from "../common/RowStyles";
import { Task } from "../type/task";
import { useFilter } from "../util/hook";
+import StateCounter from "./StatesCounter";
import { StatusChip } from "./StatusChip";
const TaskTable = ({
@@ -36,7 +38,8 @@ const TaskTable = ({
{ label: "ID" },
{ label: "Name" },
{ label: "Job Id" },
- { label: "Scheduling State" },
+ { label: "State" },
+ { label: "Duration" },
{ label: "Function or Class Name" },
{ label: "Node Id" },
{ label: "Actor_id" },
@@ -59,12 +62,12 @@ const TaskTable = ({
/>
e.scheduling_state)))}
+ options={Array.from(new Set(tasks.map((e) => e.state)))}
onInputChange={(_: any, value: string) => {
- changeFilter("scheduling_state", value.trim());
+ changeFilter("state", value.trim());
}}
renderInput={(params: TextFieldProps) => (
-
+
)}
/>
+
+
+
@@ -140,12 +146,15 @@ const TaskTable = ({
task_id,
name,
job_id,
- scheduling_state,
+ state,
func_or_class_name,
node_id,
actor_id,
type,
required_resources,
+ events,
+ start_time_ms,
+ end_time_ms,
}) => (
@@ -161,7 +170,17 @@ const TaskTable = ({
{name ? name : "-"}
{job_id}
-
+
+
+
+ {start_time_ms && start_time_ms > 0 ? (
+
+ ) : (
+ "-"
+ )}
{func_or_class_name}
{node_id ? node_id : "-"}
diff --git a/dashboard/client/src/pages/job/JobDetail.tsx b/dashboard/client/src/pages/job/JobDetail.tsx
index cddca7c6761c9..8f0b6cdf1c21d 100644
--- a/dashboard/client/src/pages/job/JobDetail.tsx
+++ b/dashboard/client/src/pages/job/JobDetail.tsx
@@ -25,7 +25,13 @@ const useStyle = makeStyles((theme) => ({
},
}));
-export const JobDetailChartsPage = () => {
+type JobDetailChartsPageProps = {
+ newIA?: boolean;
+};
+
+export const JobDetailChartsPage = ({
+ newIA = false,
+}: JobDetailChartsPageProps) => {
const classes = useStyle();
const { job, msg, params } = useJobDetail();
const jobId = params.id;
diff --git a/dashboard/client/src/pages/job/JobDetailActorPage.tsx b/dashboard/client/src/pages/job/JobDetailActorPage.tsx
new file mode 100644
index 0000000000000..c0fa8735850cb
--- /dev/null
+++ b/dashboard/client/src/pages/job/JobDetailActorPage.tsx
@@ -0,0 +1,37 @@
+import { makeStyles } from "@material-ui/core";
+import React from "react";
+
+import TitleCard from "../../components/TitleCard";
+import ActorList from "../actor/ActorList";
+import { MainNavPageInfo } from "../layout/mainNavContext";
+import { useJobDetail } from "./hook/useJobDetail";
+
+const useStyle = makeStyles((theme) => ({
+ root: {
+ padding: theme.spacing(2),
+ },
+}));
+
+export const JobDetailActorsPage = () => {
+ const classes = useStyle();
+ const { job, params } = useJobDetail();
+
+ const pageInfo = job
+ ? {
+ title: "Actors",
+ id: "actors",
+ path: job.job_id ? `/new/jobs/${job.job_id}/actors` : undefined,
+ }
+ : {
+ title: "Actors",
+ id: "actors",
+ path: undefined,
+ };
+
+ return (
+
+ );
+};
diff --git a/dashboard/client/src/pages/job/JobDetailLayout.tsx b/dashboard/client/src/pages/job/JobDetailLayout.tsx
index 3c65ccc6b42a6..f87915e451308 100644
--- a/dashboard/client/src/pages/job/JobDetailLayout.tsx
+++ b/dashboard/client/src/pages/job/JobDetailLayout.tsx
@@ -1,5 +1,9 @@
import React from "react";
-import { RiInformationLine, RiLineChartLine } from "react-icons/ri";
+import {
+ RiGradienterLine,
+ RiInformationLine,
+ RiLineChartLine,
+} from "react-icons/ri";
import { MainNavPageInfo } from "../layout/mainNavContext";
import { SideTabLayout, SideTabRouteLink } from "../layout/SideTabLayout";
import { useJobDetail } from "./hook/useJobDetail";
@@ -29,6 +33,12 @@ export const JobDetailLayout = () => {
title="Charts"
Icon={RiLineChartLine}
/>
+
);
};
diff --git a/dashboard/client/src/pages/layout/MainNavLayout.tsx b/dashboard/client/src/pages/layout/MainNavLayout.tsx
index 6a001a29eb3b7..85dabc6d97a35 100644
--- a/dashboard/client/src/pages/layout/MainNavLayout.tsx
+++ b/dashboard/client/src/pages/layout/MainNavLayout.tsx
@@ -88,13 +88,14 @@ const useMainNavBarStyles = makeStyles((theme) =>
boxShadow: "0px 1px 0px #D2DCE6",
},
logo: {
- width: 60,
display: "flex",
justifyContent: "center",
+ marginLeft: theme.spacing(2),
+ marginRight: theme.spacing(3),
},
navItem: {
- marginRight: theme.spacing(2),
- fontSize: "1em",
+ marginRight: theme.spacing(6),
+ fontSize: "1rem",
fontWeight: 500,
color: "black",
textDecoration: "none",
@@ -211,15 +212,21 @@ const MainNavBreadcrumbs = () => {
);
if (index === 0) {
return (
-
+
{linkOrText}
);
} else {
return (
- {"/"}
-
+
+ {"/"}
+
+
{linkOrText}
diff --git a/dashboard/client/src/pages/log/Logs.tsx b/dashboard/client/src/pages/log/Logs.tsx
index 6501a0c8b7786..7a36a1e23c64c 100644
--- a/dashboard/client/src/pages/log/Logs.tsx
+++ b/dashboard/client/src/pages/log/Logs.tsx
@@ -126,7 +126,7 @@ const Logs = (props: LogsProps) => {
setEnd,
} = useLogs(props);
const { newIA } = props;
- let href = newIA ? "#/new/log/" : "#/log/";
+ let href = newIA ? "#/new/logs/" : "#/log/";
if (origin) {
if (path) {
diff --git a/dashboard/client/src/pages/metrics/Metrics.tsx b/dashboard/client/src/pages/metrics/Metrics.tsx
index 554990b3250ef..8610f6961622e 100644
--- a/dashboard/client/src/pages/metrics/Metrics.tsx
+++ b/dashboard/client/src/pages/metrics/Metrics.tsx
@@ -20,9 +20,21 @@ const useStyles = makeStyles((theme) =>
display: "flex",
flexDirection: "row",
flexWrap: "wrap",
+ gap: theme.spacing(3),
+ },
+ chart: {
+ flex: "1 0 448px",
+ maxWidth: "100%",
+ height: 300,
+ overflow: "hidden",
+ [theme.breakpoints.up("md")]: {
+ // Calculate max width based on 1/3 of the total width minus padding between cards
+ maxWidth: `calc((100% - ${theme.spacing(3)}px * 2) / 3)`,
+ },
},
grafanaEmbed: {
- margin: theme.spacing(1),
+ width: "100%",
+ height: "100%",
},
topBar: {
position: "sticky",
@@ -213,15 +225,15 @@ export const Metrics = () => {
{METRICS_CONFIG.map(({ title, path }) => (
-
+
+
+
))}
diff --git a/dashboard/client/src/pages/overview/OverviewPage.tsx b/dashboard/client/src/pages/overview/OverviewPage.tsx
index e5faf12dc7e9c..91f347b627743 100644
--- a/dashboard/client/src/pages/overview/OverviewPage.tsx
+++ b/dashboard/client/src/pages/overview/OverviewPage.tsx
@@ -17,17 +17,23 @@ const useStyles = makeStyles((theme) =>
overviewCardsContainer: {
display: "flex",
flexDirection: "row",
- flexWrap: "nowrap",
+ flexWrap: "wrap",
marginBottom: theme.spacing(4),
gap: theme.spacing(3),
+ [theme.breakpoints.up("md")]: {
+ flexWrap: "nowrap",
+ },
},
overviewCard: {
flex: "1 0 448px",
- // Calculate max width based on 1/3 of the total width minus padding between cards
- maxWidth: `calc((100% - ${theme.spacing(3)}px * 2) / 3)`,
+ maxWidth: "100%",
+ [theme.breakpoints.up("md")]: {
+ // Calculate max width based on 1/3 of the total width minus padding between cards
+ maxWidth: `calc((100% - ${theme.spacing(3)}px * 2) / 3)`,
+ },
},
section: {
- marginTop: theme.spacing(2),
+ marginTop: theme.spacing(4),
},
}),
);
@@ -59,6 +65,7 @@ export const OverviewPage = () => {
className={classes.section}
title="Node metrics"
startExpanded
+ keepRendered
>
diff --git a/dashboard/client/src/pages/overview/cards/NodeCountCard.tsx b/dashboard/client/src/pages/overview/cards/NodeCountCard.tsx
index 854d08afa3376..ab53ebc46bbff 100644
--- a/dashboard/client/src/pages/overview/cards/NodeCountCard.tsx
+++ b/dashboard/client/src/pages/overview/cards/NodeCountCard.tsx
@@ -18,6 +18,7 @@ const useStyles = makeStyles((theme) =>
display: "flex",
flexDirection: "row",
flexWrap: "nowrap",
+ margin: theme.spacing(0, 3, 2),
},
}),
);
diff --git a/dashboard/client/src/pages/overview/cards/OverviewCard.tsx b/dashboard/client/src/pages/overview/cards/OverviewCard.tsx
index bdec81b33c7ee..8edae6882702c 100644
--- a/dashboard/client/src/pages/overview/cards/OverviewCard.tsx
+++ b/dashboard/client/src/pages/overview/cards/OverviewCard.tsx
@@ -7,9 +7,8 @@ import { Link } from "react-router-dom";
const useStyles = makeStyles((theme) =>
createStyles({
root: {
- padding: theme.spacing(2, 3),
- height: 270,
- borderRadius: 8,
+ height: 294,
+ overflow: "hidden",
},
}),
);
diff --git a/dashboard/client/src/pages/overview/cards/RecentJobsCard.component.test.tsx b/dashboard/client/src/pages/overview/cards/RecentJobsCard.component.test.tsx
index 5d1da1c2474d6..e1eed0a76fe3e 100644
--- a/dashboard/client/src/pages/overview/cards/RecentJobsCard.component.test.tsx
+++ b/dashboard/client/src/pages/overview/cards/RecentJobsCard.component.test.tsx
@@ -46,6 +46,6 @@ describe("RecentJobsCard", () => {
expect(screen.getByText("02000000")).toBeVisible();
expect(screen.getByText("raysubmit_23456")).toBeVisible();
expect(screen.getByText("04000000")).toBeVisible();
- expect(screen.getByText("05000000")).toBeVisible();
+ expect(screen.queryByText("05000000")).toBeNull();
});
});
diff --git a/dashboard/client/src/pages/overview/cards/RecentJobsCard.tsx b/dashboard/client/src/pages/overview/cards/RecentJobsCard.tsx
index 71c60735cd12b..e79526caf4386 100644
--- a/dashboard/client/src/pages/overview/cards/RecentJobsCard.tsx
+++ b/dashboard/client/src/pages/overview/cards/RecentJobsCard.tsx
@@ -17,6 +17,7 @@ const useStyles = makeStyles((theme) =>
root: {
display: "flex",
flexDirection: "column",
+ padding: theme.spacing(2, 3),
},
listContainer: {
marginTop: theme.spacing(2),
@@ -39,7 +40,7 @@ export const RecentJobsCard = ({ className }: RecentJobsCardProps) => {
const classes = useStyles();
const { jobList } = useJobList();
- const sortedJobs = _.orderBy(jobList, ["startTime"], ["desc"]).slice(0, 5);
+ const sortedJobs = _.orderBy(jobList, ["startTime"], ["desc"]).slice(0, 4);
return (
@@ -108,6 +109,7 @@ const useRecentJobListItemStyles = makeStyles((theme) =>
overflow: "hidden",
textOverflow: "ellipsis",
whiteSpace: "nowrap",
+ color: "#5F6469",
},
}),
);
@@ -148,10 +150,14 @@ const RecentJobListItem = ({ job, className }: RecentJobListItemProps) => {
{icon}
-
+
{job.job_id ?? job.submission_id}
-
+
{job.entrypoint}
diff --git a/dashboard/client/src/pages/state/task.tsx b/dashboard/client/src/pages/state/task.tsx
index a9363ad18525e..8c5aa446e5cad 100644
--- a/dashboard/client/src/pages/state/task.tsx
+++ b/dashboard/client/src/pages/state/task.tsx
@@ -7,7 +7,7 @@ import { Task } from "../../type/task";
import { useStateApiList } from "./hook/useStateApi";
/**
- * Represent the embedable actors page.
+ * Represent the embedable tasks page.
*/
const TaskList = ({ jobId = null }: { jobId?: string | null }) => {
const [timeStamp] = useState(dayjs());
diff --git a/dashboard/client/src/service/placementGroup.ts b/dashboard/client/src/service/placementGroup.ts
index 29d1a383739c1..11eee66c95550 100644
--- a/dashboard/client/src/service/placementGroup.ts
+++ b/dashboard/client/src/service/placementGroup.ts
@@ -4,6 +4,6 @@ import { get } from "./requestHandlers";
export const getPlacementGroup = () => {
return get>(
- "api/v0/placement_groups?detail=1",
+ "api/v0/placement_groups?detail=1&limit=10000",
);
};
diff --git a/dashboard/client/src/service/task.ts b/dashboard/client/src/service/task.ts
index d4e89baa7e66e..50bb04e1e7899 100644
--- a/dashboard/client/src/service/task.ts
+++ b/dashboard/client/src/service/task.ts
@@ -3,5 +3,5 @@ import { Task } from "../type/task";
import { get } from "./requestHandlers";
export const getTasks = () => {
- return get>("api/v0/tasks?detail=1");
+ return get>("api/v0/tasks?detail=1&limit=10000");
};
diff --git a/dashboard/client/src/theme.ts b/dashboard/client/src/theme.ts
index 2f5021a5fbc10..6259f9b65b221 100644
--- a/dashboard/client/src/theme.ts
+++ b/dashboard/client/src/theme.ts
@@ -32,6 +32,14 @@ const basicTheme: ThemeOptions = {
body1: {
fontSize: "0.75rem",
},
+ body2: {
+ fontSize: "14px",
+ lineHeight: "20px",
+ },
+ caption: {
+ fontSize: "0.75rem",
+ lineHeight: "16px",
+ },
},
props: {
MuiPaper: {
@@ -39,6 +47,13 @@ const basicTheme: ThemeOptions = {
},
},
overrides: {
+ MuiCssBaseline: {
+ "@global": {
+ a: {
+ color: "#036DCF",
+ },
+ },
+ },
MuiTooltip: {
tooltip: {
fontSize: "0.75rem",
@@ -56,6 +71,7 @@ const basicTheme: ThemeOptions = {
MuiPaper: {
outlined: {
borderColor: "#D2DCE6",
+ borderRadius: 8,
},
},
},
diff --git a/dashboard/client/src/type/task.ts b/dashboard/client/src/type/task.ts
index b100dcb34a7cc..107e00192536d 100644
--- a/dashboard/client/src/type/task.ts
+++ b/dashboard/client/src/type/task.ts
@@ -24,7 +24,7 @@ export enum TypeTaskType {
export type Task = {
task_id: string;
name: string;
- scheduling_state: TypeTaskStatus;
+ state: TypeTaskStatus;
job_id: string;
node_id: string;
actor_id: string;
@@ -33,4 +33,7 @@ export type Task = {
language: string;
required_resources: { [key: string]: number };
runtime_env_info: string;
+ events: { [key: string]: string }[];
+ start_time_ms: number | null;
+ end_time_ms: number | null;
};
diff --git a/dashboard/dashboard.py b/dashboard/dashboard.py
index 406bfaebfe638..55f465bf024ff 100644
--- a/dashboard/dashboard.py
+++ b/dashboard/dashboard.py
@@ -35,6 +35,8 @@ class Dashboard:
port: Port number of dashboard aiohttp server.
port_retries: The retry times to select a valid port.
gcs_address: GCS address of the cluster
+ serve_frontend: If configured, frontend HTML
+ is not served from the dashboard.
log_dir: Log directory of dashboard.
"""
@@ -48,6 +50,7 @@ def __init__(
temp_dir: str = None,
session_dir: str = None,
minimal: bool = False,
+ serve_frontend: bool = True,
modules_to_load: Optional[Set[str]] = None,
):
self.dashboard_head = dashboard_head.DashboardHead(
@@ -59,6 +62,7 @@ def __init__(
temp_dir=temp_dir,
session_dir=session_dir,
minimal=minimal,
+ serve_frontend=serve_frontend,
modules_to_load=modules_to_load,
)
@@ -166,6 +170,11 @@ async def run(self):
"If nothing is specified, all modules are loaded."
),
)
+ parser.add_argument(
+ "--disable-frontend",
+ action="store_true",
+ help=("If configured, frontend html is not served from the server."),
+ )
args = parser.parse_args()
@@ -190,7 +199,6 @@ async def run(self):
# which assumes a working event loop. Ref:
# https://github.com/grpc/grpc/blob/master/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi#L174-L188
loop = ray._private.utils.get_or_create_event_loop()
-
dashboard = Dashboard(
args.host,
args.port,
@@ -200,6 +208,7 @@ async def run(self):
temp_dir=args.temp_dir,
session_dir=args.session_dir,
minimal=args.minimal,
+ serve_frontend=(not args.disable_frontend),
modules_to_load=modules_to_load,
)
diff --git a/dashboard/head.py b/dashboard/head.py
index fe4aec642f248..e28c2cc5a32b5 100644
--- a/dashboard/head.py
+++ b/dashboard/head.py
@@ -78,6 +78,7 @@ def __init__(
temp_dir: str,
session_dir: str,
minimal: bool,
+ serve_frontend: bool,
modules_to_load: Optional[Set[str]] = None,
):
"""
@@ -90,12 +91,18 @@ def __init__(
temp_dir: The temp directory. E.g., /tmp.
session_dir: The session directory. E.g., tmp/session_latest.
minimal: Whether or not it will load the minimal modules.
+ serve_frontend: If configured, frontend HTML is
+ served from the dashboard.
modules_to_load: A set of module name in string to load.
By default (None), it loads all available modules.
Note that available modules could be changed depending on
minimal flags.
"""
self.minimal = minimal
+ self.serve_frontend = serve_frontend
+ # If it is the minimal mode, we shouldn't serve frontend.
+ if self.minimal:
+ self.serve_frontend = False
self.health_check_thread: GCSHealthCheckThread = None
self._gcs_rpc_error_counter = 0
# Public attributes are accessible for all head modules.
@@ -290,9 +297,12 @@ async def _async_notify():
modules = self._load_modules(self._modules_to_load)
http_host, http_port = self.http_host, self.http_port
- if not self.minimal:
+ if self.serve_frontend:
+ logger.info("Initialize the http server.")
self.http_server = await self._configure_http_server(modules)
http_host, http_port = self.http_server.get_address()
+ else:
+ logger.info("http server disabled.")
await asyncio.gather(
self.gcs_aio_client.internal_kv_put(
ray_constants.DASHBOARD_ADDRESS.encode(),
diff --git a/dashboard/modules/metrics/grafana_dashboard_base.json b/dashboard/modules/metrics/grafana_dashboard_base.json
index ac5583e8216e0..6edde3d5b8d6c 100644
--- a/dashboard/modules/metrics/grafana_dashboard_base.json
+++ b/dashboard/modules/metrics/grafana_dashboard_base.json
@@ -55,7 +55,7 @@
"useTags": false
},
{
- "allValue": null,
+ "allValue": ".+",
"current": {
"selected": true,
"text": [
diff --git a/dashboard/state_aggregator.py b/dashboard/state_aggregator.py
index 46d37d7e880c4..ebec09d7b9177 100644
--- a/dashboard/state_aggregator.py
+++ b/dashboard/state_aggregator.py
@@ -1,5 +1,6 @@
import asyncio
import logging
+import json
from dataclasses import asdict, fields
from itertools import islice
@@ -209,7 +210,13 @@ async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse:
for message in reply.actor_table_data:
data = self._message_to_dict(
message=message,
- fields_to_decode=["actor_id", "owner_id", "job_id", "node_id"],
+ fields_to_decode=[
+ "actor_id",
+ "owner_id",
+ "job_id",
+ "node_id",
+ "placement_group_id",
+ ],
)
result.append(data)
num_after_truncation = len(result)
@@ -279,6 +286,9 @@ async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
for message in reply.node_info_list:
data = self._message_to_dict(message=message, fields_to_decode=["node_id"])
data["node_ip"] = data["node_manager_address"]
+ data["start_time_ms"] = int(data["start_time_ms"])
+ data["end_time_ms"] = int(data["end_time_ms"])
+
result.append(data)
total_nodes = len(result)
@@ -318,6 +328,8 @@ async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
data["worker_id"] = data["worker_address"]["worker_id"]
data["node_id"] = data["worker_address"]["raylet_id"]
data["ip"] = data["worker_address"]["ip_address"]
+ data["start_time_ms"] = int(data["start_time_ms"])
+ data["end_time_ms"] = int(data["end_time_ms"])
result.append(data)
num_after_truncation = len(result)
@@ -359,9 +371,16 @@ async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
{task_id -> task_data_in_dict}
task_data_in_dict's schema is in TaskState
"""
+ job_id = None
+ for filter in option.filters:
+ if filter[0] == "job_id":
+ # tuple consists of (job_id, predicate, value)
+ job_id = filter[2]
try:
- reply = await self._client.get_all_task_info(timeout=option.timeout)
+ reply = await self._client.get_all_task_info(
+ timeout=option.timeout, job_id=job_id
+ )
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
@@ -371,7 +390,15 @@ def _to_task_state(task_attempt: dict) -> dict:
"""
task_state = {}
task_info = task_attempt.get("task_info", {})
- state_updates = task_attempt.get("state_updates", None)
+ state_updates = task_attempt.get("state_updates", [])
+ profiling_data = task_attempt.get("profiling_data", {})
+ if profiling_data:
+ for event in profiling_data["events"]:
+ # End/start times are recorded in ns. We convert them to ms.
+ event["end_time"] = int(event["end_time"]) / 1e6
+ event["start_time"] = int(event["start_time"]) / 1e6
+ event["extra_data"] = json.loads(event["extra_data"])
+ task_state["profiling_data"] = profiling_data
# Convert those settable fields
mappings = [
@@ -387,25 +414,42 @@ def _to_task_state(task_attempt: dict) -> dict:
"required_resources",
"runtime_env_info",
"parent_task_id",
+ "placement_group_id",
],
),
(task_attempt, ["task_id", "attempt_number", "job_id"]),
- (state_updates, ["node_id"]),
+ (state_updates, ["node_id", "worker_id"]),
]
for src, keys in mappings:
for key in keys:
task_state[key] = src.get(key)
- # Get the most updated scheduling_state by state transition ordering.
- def _get_most_recent_status(task_state: dict) -> str:
- # Reverse the order as defined in protobuf for the most recent state.
- for status_name in reversed(common_pb2.TaskStatus.keys()):
- key = f"{status_name.lower()}_ts"
- if state_updates.get(key):
- return status_name
- return common_pb2.TaskStatus.Name(common_pb2.NIL)
-
- task_state["scheduling_state"] = _get_most_recent_status(state_updates)
+ task_state["start_time_ms"] = None
+ task_state["end_time_ms"] = None
+ events = []
+
+ for state in common_pb2.TaskStatus.keys():
+ key = f"{state.lower()}_ts"
+ if key in state_updates:
+ # timestamp is recorded in ns.
+ ts_ms = int(state_updates[key]) // 1e6
+ events.append(
+ {
+ "state": state,
+ "created_ms": ts_ms,
+ }
+ )
+ if state == "RUNNING":
+ task_state["start_time_ms"] = ts_ms
+ if state == "FINISHED" or state == "FAILED":
+ task_state["end_time_ms"] = ts_ms
+
+ task_state["events"] = events
+ if len(events) > 0:
+ latest_state = events[-1]["state"]
+ else:
+ latest_state = common_pb2.TaskStatus.Name(common_pb2.NIL)
+ task_state["state"] = latest_state
return task_state
@@ -419,6 +463,9 @@ def _get_most_recent_status(task_state: dict) -> str:
"node_id",
"actor_id",
"parent_task_id",
+ "worker_id",
+ "placement_group_id",
+ "component_id",
],
)
)
diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py
index 36925f72ba04f..f0107daebbff4 100644
--- a/dashboard/tests/test_dashboard.py
+++ b/dashboard/tests/test_dashboard.py
@@ -19,6 +19,8 @@
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.modules
import ray.dashboard.utils as dashboard_utils
+from click.testing import CliRunner
+from requests.exceptions import ConnectionError
from ray._private import ray_constants
from ray._private.ray_constants import (
DEBUG_AUTOSCALING_ERROR,
@@ -34,6 +36,7 @@
wait_until_server_available,
wait_until_succeeded_without_exception,
)
+import ray.scripts.scripts as scripts
from ray.dashboard import dashboard
from ray.dashboard.head import DashboardHead
from ray.experimental.state.api import StateApiClient
@@ -900,7 +903,7 @@ def test_agent_does_not_depend_on_serve(shutdown_only):
os.environ.get("RAY_MINIMAL") == "1" or os.environ.get("RAY_DEFAULT") == "1",
reason="This test is not supposed to work for minimal or default installation.",
)
-def test_agent_port_conflict():
+def test_agent_port_conflict(shutdown_only):
ray.shutdown()
# start ray and test agent works.
@@ -989,6 +992,7 @@ def test_dashboard_module_load(tmpdir):
str(tmpdir),
str(tmpdir),
False,
+ True,
)
# Test basic.
@@ -1045,5 +1049,75 @@ def test_dashboard_module_no_warnings(enable_test_module):
debug._disabled = old_val
+def test_dashboard_not_included_ray_init(shutdown_only, capsys):
+ addr = ray.init(include_dashboard=False, dashboard_port=8265)
+ dashboard_url = addr["webui_url"]
+ assert "View the dashboard" not in capsys.readouterr().err
+ assert not dashboard_url
+
+ # Warm up.
+ @ray.remote
+ def f():
+ pass
+
+ ray.get(f.remote())
+
+ with pytest.raises(ConnectionError):
+ # Since the dashboard doesn't start, it should raise ConnectionError
+ # becasue we cannot estabilish a connection.
+ requests.get("http://localhost:8265")
+
+
+def test_dashboard_not_included_ray_start(shutdown_only, capsys):
+ runner = CliRunner()
+ try:
+ runner.invoke(
+ scripts.start,
+ ["--head", "--include-dashboard=False", "--dashboard-port=8265"],
+ )
+ addr = ray.init("auto")
+ dashboard_url = addr["webui_url"]
+ assert not dashboard_url
+
+ assert "view the dashboard at" not in capsys.readouterr().err
+
+ # Warm up.
+ @ray.remote
+ def f():
+ pass
+
+ ray.get(f.remote())
+
+ with pytest.raises(ConnectionError):
+ # Since the dashboard doesn't start, it should raise ConnectionError
+ # becasue we cannot estabilish a connection.
+ requests.get("http://localhost:8265")
+ finally:
+ runner.invoke(scripts.stop, ["--force"])
+
+
+@pytest.mark.skipif(
+ os.environ.get("RAY_MINIMAL") != "1",
+ reason="This test only works for minimal installation.",
+)
+def test_dashboard_not_included_ray_minimal(shutdown_only, capsys):
+ addr = ray.init(dashboard_port=8265)
+ dashboard_url = addr["webui_url"]
+ assert "View the dashboard" not in capsys.readouterr().err
+ assert not dashboard_url
+
+ # Warm up.
+ @ray.remote
+ def f():
+ pass
+
+ ray.get(f.remote())
+
+ with pytest.raises(ConnectionError):
+ # Since the dashboard doesn't start, it should raise ConnectionError
+ # becasue we cannot estabilish a connection.
+ requests.get("http://localhost:8265")
+
+
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
diff --git a/doc/source/_static/js/custom.js b/doc/source/_static/js/custom.js
index e745f4e140054..7db383309c9be 100644
--- a/doc/source/_static/js/custom.js
+++ b/doc/source/_static/js/custom.js
@@ -27,3 +27,19 @@ function loadVisibleTermynals() {
window.addEventListener("scroll", loadVisibleTermynals);
createTermynals();
loadVisibleTermynals();
+
+// Remember the scroll position when the page is unloaded.
+window.onload = function() {
+ let sidebar = document.querySelector("#bd-docs-nav");
+
+ window.onbeforeunload = function() {
+ let scroll = sidebar.scrollTop;
+ localStorage.setItem("scroll", scroll);
+ }
+
+ let storedScrollPosition = localStorage.getItem("scroll");
+ if (storedScrollPosition) {
+ sidebar.scrollTop = storedScrollPosition;
+ localStorage.removeItem("scroll");
+ }
+};
diff --git a/doc/source/_toc.yml b/doc/source/_toc.yml
index bbc60b28255d6..48ba24c5daa55 100644
--- a/doc/source/_toc.yml
+++ b/doc/source/_toc.yml
@@ -236,6 +236,7 @@ parts:
- file: rllib/rllib-replay-buffers
- file: rllib/rllib-offline
- file: rllib/rllib-connector
+ - file: rllib/rllib-fault-tolerance
- file: rllib/rllib-dev
- file: rllib/rllib-cli
- file: rllib/rllib-examples
diff --git a/doc/source/cluster/kubernetes/user-guides/config.md b/doc/source/cluster/kubernetes/user-guides/config.md
index 3dd7b14f0d56e..49603e20753b6 100644
--- a/doc/source/cluster/kubernetes/user-guides/config.md
+++ b/doc/source/cluster/kubernetes/user-guides/config.md
@@ -212,6 +212,7 @@ For most use-cases, this field should be set to "0.0.0.0" for the Ray head pod.
This is required to expose the Ray dashboard outside the Ray cluster. (Future versions might set
this parameter by default.)
+(kuberay-num-cpus)=
### num-cpus
This optional field tells the Ray scheduler and autoscaler how many CPUs are
available to the Ray pod. The CPU count can be autodetected from the
diff --git a/doc/source/cluster/vms/user-guides/community/spark.rst b/doc/source/cluster/vms/user-guides/community/spark.rst
index 9192afed1ace0..06c765ed6430b 100644
--- a/doc/source/cluster/vms/user-guides/community/spark.rst
+++ b/doc/source/cluster/vms/user-guides/community/spark.rst
@@ -17,7 +17,7 @@ Assuming the python file name is 'ray-on-spark-example1.py'.
.. code-block:: python
from pyspark.sql import SparkSession
- from ray.util.spark import init_ray_cluster, shutdown_ray_cluster, MAX_NUM_WORKER_NODES
+ from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster, MAX_NUM_WORKER_NODES
if __name__ == "__main__":
spark = SparkSession \
.builder \
@@ -25,21 +25,23 @@ Assuming the python file name is 'ray-on-spark-example1.py'.
.config("spark.task.cpus", "4") \
.getOrCreate()
- # initiate a ray cluster on this spark application, it creates a background
+ # Set up a ray cluster on this spark application, it creates a background
# spark job that each spark task launches one ray worker node.
# ray head node is launched in spark application driver side.
# Resources (CPU / GPU / memory) allocated to each ray worker node is equal
# to resources allocated to the corresponding spark task.
- init_ray_cluster(num_worker_nodes=MAX_NUM_WORKER_NODES)
+ setup_ray_cluster(num_worker_nodes=MAX_NUM_WORKER_NODES)
# You can any ray application code here, the ray application will be executed
# on the ray cluster setup above.
- # Note that you don't need to call `ray.init`.
+ # You don't need to set address for `ray.init`,
+ # it will connect to the cluster created above automatically.
+ ray.init()
...
# Terminate ray cluster explicitly.
- # If you don't call it, when spark application is terminated, the ray cluster will
- # also be terminated.
+ # If you don't call it, when spark application is terminated, the ray cluster
+ # will also be terminated.
shutdown_ray_cluster()
2) Submit the spark application above to spark standalone cluster.
@@ -64,7 +66,7 @@ Assuming the python file name is 'long-running-ray-cluster-on-spark.py'.
from pyspark.sql import SparkSession
import time
- from ray.util.spark import init_ray_cluster, MAX_NUM_WORKER_NODES
+ from ray.util.spark import setup_ray_cluster, MAX_NUM_WORKER_NODES
if __name__ == "__main__":
spark = SparkSession \
@@ -73,9 +75,11 @@ Assuming the python file name is 'long-running-ray-cluster-on-spark.py'.
.config("spark.task.cpus", "4") \
.getOrCreate()
- cluster_address = init_ray_cluster(num_worker_nodes=MAX_NUM_WORKER_NODES)
- print("Ray cluster is initiated, you can connect to this ray cluster via address "
- f"ray://{cluster_address}")
+ cluster_address = setup_ray_cluster(
+ num_worker_nodes=MAX_NUM_WORKER_NODES
+ )
+ print("Ray cluster is set up, you can connect to this ray cluster "
+ f"via address ray://{cluster_address}")
# Sleep forever until the spark application being terminated,
# at that time, the ray cluster will also be terminated.
@@ -90,3 +94,10 @@ Assuming the python file name is 'long-running-ray-cluster-on-spark.py'.
spark-submit \
--master spark://{spark_master_IP}:{spark_master_port} \
path/to/long-running-ray-cluster-on-spark.py
+
+Ray on Spark APIs
+-----------------
+
+.. autofunction:: ray.util.spark.setup_ray_cluster
+
+.. autofunction:: ray.util.spark.shutdown_ray_cluster
diff --git a/doc/source/cluster/vms/user-guides/large-cluster-best-practices.rst b/doc/source/cluster/vms/user-guides/large-cluster-best-practices.rst
index 312e785886044..49a031bef4dfd 100644
--- a/doc/source/cluster/vms/user-guides/large-cluster-best-practices.rst
+++ b/doc/source/cluster/vms/user-guides/large-cluster-best-practices.rst
@@ -56,10 +56,13 @@ architecture means that the head node will have extra stress due to GCS.
resource on the head node is outbound bandwidth. For large clusters (see the
scalability envelope), we recommend using machines networking characteristics
at least as good as an r5dn.16xlarge on AWS EC2.
-* Set ``resources: {"CPU": 0}`` on the head node. (For Ray clusters deployed using Helm,
- set ``rayResources: {"CPU": 0}``.) Due to the heavy networking
- load (and the GCS and dashboard processes), we recommend setting the number of
- CPUs to 0 on the head node to avoid scheduling additional tasks on it.
+* Set ``resources: {"CPU": 0}`` on the head node.
+ (For Ray clusters deployed using KubeRay,
+ set ``rayStartParams: {"num-cpus": "0"}``.
+ See the :ref:`configuration guide for KubeRay clusters `.)
+ Due to the heavy networking load (and the GCS and dashboard processes), we
+ recommend setting the number of CPUs to 0 on the head node to avoid
+ scheduling additional tasks on it.
Configuring the autoscaler
^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/doc/source/ray-observability/ray-metrics.rst b/doc/source/ray-observability/ray-metrics.rst
index d5feaf979bbe6..d05363102cbe2 100644
--- a/doc/source/ray-observability/ray-metrics.rst
+++ b/doc/source/ray-observability/ray-metrics.rst
@@ -69,6 +69,10 @@ allows you to create custom dashboards with your favorite metrics. Ray exports s
configurations which includes a default dashboard showing some of the most valuable metrics
for debugging ray applications.
+
+Deploying Grafana
+~~~~~~~~~~~~~~~~~
+
First, `download Grafana `_. Follow the instructions on the download page to download the right binary for your operating system.
Then go to to the location of the binary and run grafana using the built in configuration found in `/tmp/ray/session_latest/metrics/grafana` folder.
@@ -87,6 +91,41 @@ You can then see the default dashboard by going to dashboards -> manage -> Ray -
.. image:: images/graphs.png
:align: center
+Using an existing Grafana instance
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When you want to use existing Grafana instance, before starting your Ray cluster you will need to setup environment variable `RAY_GRAFANA_HOST` with an URL of your Grafana. After starting Ray, you can find Grafana dashboard json at `/tmp/ray/session_latest/metrics/grafana/dashboards/default_grafana_dashboard.json`. `Import this dashboard `_ to your Grafana.
+
+If Grafana reports that datasource is not found, you can `add a datasource variable `_ and using `JSON model view `_ change all values of `datasource` key in the imported `default_grafana_dashboard.json` to the name of the variable. For example, if the variable name is `data_source`, all `"datasource"` mappings should be:
+
+.. code-block:: json
+
+ "datasource": {
+ "type": "prometheus",
+ "uid": "$data_source"
+ }
+
+When existing Grafana instance requires user authentication, the following settings have to be in its `configuration file `_ to correctly embed in Ray dashboard:
+
+.. code-block:: ini
+
+ [security]
+ allow_embedding = true
+ cookie_secure = true
+ cookie_samesite = none
+
+If Grafana is exposed via nginx ingress on Kubernetes cluster, the following line should be present in the Grafana ingress annotation:
+
+.. code-block:: yaml
+
+ nginx.ingress.kubernetes.io/configuration-snippet: |
+ add_header X-Frame-Options SAMEORIGIN always;
+
+When both Grafana and Ray cluster are on the same Kubernetes cluster, it is important to set `RAY_GRAFANA_HOST` to the external URL of the Grafana ingress. For successful embedding, `RAY_GRAFANA_HOST` needs to be accessible to both Ray cluster backend and Ray dashboard frontend:
+
+* On the backend, *Ray cluster head* does health checks on Grafana. Hence `RAY_GRAFANA_HOST` needs to be accessible in the Kubernetes pod which is running the head node.
+* When accessing *Ray dashboard* from the browser, frontend embeds Grafana dashboard using the URL specified in `RAY_GRAFANA_HOST`. Hence `RAY_GRAFANA_HOST` needs to be accessible from the browser as well.
+
.. _system-metrics:
System Metrics
@@ -240,3 +279,9 @@ When downloading binaries from the internet, Mac requires that the binary be sig
Unfortunately, many developers today are not trusted by Mac and so this requirement must be overridden by the user manaully.
See `these instructions `_ on how to override the restriction and install or run the application.
+
+Grafana dashboards are not embedded in the Ray dashboard
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+If you're getting error that `RAY_GRAFANA_HOST` is not setup despite you've set it up, please check:
+That you've included protocol in the URL (e.g. `http://your-grafana-url.com` instead of `your-grafana-url.com`).
+Also, make sure that url doesn't have trailing slash (e.g. `http://your-grafana-url.com` instead of `http://your-grafana-url.com/`).
\ No newline at end of file
diff --git a/doc/source/ray-observability/state/state-api.rst b/doc/source/ray-observability/state/state-api.rst
index d6a87a7c46588..04bd682e6d07e 100644
--- a/doc/source/ray-observability/state/state-api.rst
+++ b/doc/source/ray-observability/state/state-api.rst
@@ -320,14 +320,14 @@ E.g., List running tasks
.. code-block:: bash
- ray list tasks -f scheduling_state=RUNNING
+ ray list tasks -f state=RUNNING
.. tabbed:: Python SDK
.. code-block:: python
from ray.experimental.state.api import list_tasks
- list_tasks(filters=[("scheduling_state", "=", "RUNNING")])
+ list_tasks(filters=[("state", "=", "RUNNING")])
E.g., List non-running tasks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -336,14 +336,14 @@ E.g., List non-running tasks
.. code-block:: bash
- ray list tasks -f scheduling_state!=RUNNING
+ ray list tasks -f state!=RUNNING
.. tabbed:: Python SDK
.. code-block:: python
from ray.experimental.state.api import list_tasks
- list_tasks(filters=[("scheduling_state", "!=", "RUNNING")])
+ list_tasks(filters=[("state", "!=", "RUNNING")])
E.g., List running tasks that have a name func
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -352,14 +352,14 @@ E.g., List running tasks that have a name func
.. code-block:: bash
- ray list tasks -f scheduling_state=RUNNING -f name="task_running_300_seconds()"
+ ray list tasks -f state=RUNNING -f name="task_running_300_seconds()"
.. tabbed:: Python SDK
.. code-block:: python
from ray.experimental.state.api import list_tasks
- list_tasks(filters=[("scheduling_state", "=", "RUNNING"), ("name", "=", "task_running_300_seconds()")])
+ list_tasks(filters=[("state", "=", "RUNNING"), ("name", "=", "task_running_300_seconds()")])
E.g., List tasks with more details
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/doc/source/ray-overview/eco-gallery.yml b/doc/source/ray-overview/eco-gallery.yml
index a9dda3dae359c..ff47b975a1782 100644
--- a/doc/source/ray-overview/eco-gallery.yml
+++ b/doc/source/ray-overview/eco-gallery.yml
@@ -180,3 +180,11 @@ projects:
website: https://github.com/ray-project/lightgbm_ray
repo: https://github.com/ray-project/lightgbm_ray
image: ../images/lightgbm_logo.png
+ - name: Volcano Integration
+ section_title: Volcano
+ description: Volcano is system for running high-performance workloads
+ on Kubernetes. It features powerful batch scheduling capabilities required by ML
+ and other data-intensive workloads.
+ website: https://github.com/volcano-sh/volcano/releases/tag/v1.7.0
+ repo: https://github.com/volcano-sh/volcano/
+ image: ./images/volcano.png
diff --git a/doc/source/ray-overview/images/volcano.png b/doc/source/ray-overview/images/volcano.png
new file mode 100644
index 0000000000000..dd413704009b9
Binary files /dev/null and b/doc/source/ray-overview/images/volcano.png differ
diff --git a/doc/source/ray-overview/installation.rst b/doc/source/ray-overview/installation.rst
index 28825058a82a9..3c7f41dbe18e2 100644
--- a/doc/source/ray-overview/installation.rst
+++ b/doc/source/ray-overview/installation.rst
@@ -11,8 +11,8 @@ Official Releases
From Wheels
~~~~~~~~~~~
-You can install the latest official version of Ray from PyPI on linux, windows
-and macos as follows:
+You can install the latest official version of Ray from PyPI on Linux, Windows
+and macOS as follows:
.. code-block:: bash
@@ -119,7 +119,7 @@ Here's a summary of the variations:
* For Python 3.8 and 3.9, the ``m`` before the OS version should be deleted and the OS version for MacOS should read ``macosx_10_15_x86_64`` instead of ``macosx_10_15_intel``.
-* For MacOS, commits predating August 7, 2021 will have ``macosx_10_13`` in the filename instad of ``macosx_10_15``.
+* For MacOS, commits predating August 7, 2021 will have ``macosx_10_13`` in the filename instead of ``macosx_10_15``.
.. _ray-install-java:
@@ -196,7 +196,7 @@ You can install and use Ray C++ API as follows.
.. note::
- If you build Ray from source, please remove the build option ``build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"`` from the file ``cpp/example/.bazelrc`` before running your application. The related issue is `this `_.
+ If you build Ray from source, remove the build option ``build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"`` from the file ``cpp/example/.bazelrc`` before running your application. The related issue is `this `_.
.. _apple-silcon-supprt:
@@ -268,7 +268,7 @@ on the AUR page of ``python-ray`` `here`_.
Installing From conda-forge
---------------------------
-Ray can also be installed as a conda package on linux and windows
+Ray can also be installed as a conda package on Linux and Windows.
.. code-block:: bash
@@ -282,7 +282,7 @@ Ray can also be installed as a conda package on linux and windows
# Install Ray with minimal dependencies
# conda install -c conda-forge ray
-To install Ray libraries, you can use ``pip`` as above or ``conda``/``mamba``
+To install Ray libraries, use ``pip`` as above or ``conda``/``mamba``.
.. code-block:: bash
diff --git a/doc/source/rllib/rllib-fault-tolerance.rst b/doc/source/rllib/rllib-fault-tolerance.rst
new file mode 100644
index 0000000000000..24e64fee9dff8
--- /dev/null
+++ b/doc/source/rllib/rllib-fault-tolerance.rst
@@ -0,0 +1,91 @@
+.. include:: /_includes/rllib/announcement.rst
+
+.. include:: /_includes/rllib/we_are_hiring.rst
+
+Fault Tolerance And Elastic Training
+====================================
+
+RLlib handles common failures modes, such as machine failures, spot instance preemption,
+network outages, or Ray cluster failures.
+
+There are three main areas for RLlib fault tolerance support:
+
+* Worker recovery
+* Environment fault tolerance
+* Experiment level fault tolerance with Ray Tune
+
+
+Worker Recovery
+---------------
+
+RLlib supports self-recovering and elastic WorkerSets for both
+:ref:`rollout and evaluation Workers `.
+This provides fault tolerance at worker level.
+
+This means that if you have rollout workers sitting on different machines and a
+machine is pre-empted, RLlib can continue training and evaluation with minimal interruption.
+
+The two properties that RLlib supports here are self-recovery and elasticity:
+
+* **Elasticity**: RLlib continues training even when workers are removed. For example, if an RLlib trial uses spot instances, nodes may be removed from the cluster, potentially resulting in a subset of workers not getting scheduled. In this case, RLlib will continue with whatever healthy workers left at a reduced speed.
+* **Self-Recovery**: When possible, RLlib will attempt to restore workers that were previously removed. During restoration, RLlib sync the latest state before new episodes can be sampled.
+
+
+Worker fault tolerance can be turned on by setting config ``recreate_failed_workers`` to True.
+
+RLlib achieves this by utilizing a
+`state-aware and fault tolerant actor manager `__. Under the hood, RLlib relies on Ray Core :ref:`actor fault tolerance ` to automatically recover failed worker actors.
+
+Env Fault Tolerance
+-------------------
+
+In addition to worker fault tolerance, RLlib offers fault tolerance at environment level as well.
+
+Rollout or evaluation workers will often run multiple environments in parallel to take
+advantage of, for example, the parallel computing power that GPU offers. This can be controlled with
+the ``num_envs_per_worker`` config. It may then be wasteful if the entire worker needs to be
+reconstructed because of errors from a single environment.
+
+In that case, RLlib offers the capability to restart individual environments without bubbling the
+errors to higher level components. You can do that easily by turning on config
+``restart_failed_sub_environments``.
+
+.. note::
+ Environment restarts are blocking.
+
+ A rollout worker will wait until the environment comes back and finishes initialization.
+ So for on-policy algorithms, it may be better to recover at worker level to make sure
+ training progresses with elastic worker set while the environments are being reconstructed.
+ More specifically, use configs ``num_envs_per_worker=1``, ``restart_failed_sub_environments=False``,
+ and ``recreate_failed_workers=True``.
+
+
+Fault Tolerance and Recovery Provided by Ray Tune
+-------------------------------------------------
+
+Ray Tune provides fault tolerance and recovery at the experiment trial level.
+
+When using Ray Tune with RLlib, you can enable
+:ref:`periodic checkpointing `,
+which saves the state of the experiment to a user-specified persistent storage location.
+If a trial fails, Ray Tune will automatically restart it from the latest
+:ref:`checkpointed ` state.
+
+
+Other Miscellaneous Considerations
+----------------------------------
+
+By default, RLlib runs health checks during initial worker construction.
+The whole job will error out if a completely healthy worker fleet can not be established
+at the start of a training run. If an environment is by nature flaky, you may want to turn
+off this feature by setting config ``validate_workers_after_construction`` to False.
+
+Lastly, in an extreme case where no healthy workers are left for training, RLlib will wait
+certain number of iterations for some of the workers to recover before the entire training
+job failed.
+The number of iterations it waits can be configured with the config
+``num_consecutive_worker_failures_tolerance``.
+
+..
+ TODO(jungong) : move fault tolerance related options into a separate AlgorithmConfig
+ group and update the doc here.
diff --git a/doc/source/serve/doc_code/gradio_dag_visualize.py b/doc/source/serve/doc_code/gradio_dag_visualize.py
new file mode 100644
index 0000000000000..a64265689fcb0
--- /dev/null
+++ b/doc/source/serve/doc_code/gradio_dag_visualize.py
@@ -0,0 +1,64 @@
+# __doc_import_begin__
+import requests
+from transformers import pipeline
+from io import BytesIO
+from PIL import Image, ImageFile
+from typing import Dict
+
+from ray import serve
+from ray.dag.input_node import InputNode
+from ray.serve.drivers import DAGDriver
+
+# __doc_import_end__
+
+
+# __doc_downloader_begin__
+@serve.deployment
+def downloader(image_url: str) -> ImageFile.ImageFile:
+ image_bytes = requests.get(image_url).content
+ image = Image.open(BytesIO(image_bytes)).convert("RGB")
+ return image
+ # __doc_downloader_end__
+
+
+# __doc_classifier_begin__
+@serve.deployment
+class ImageClassifier:
+ def __init__(self):
+ self.model = pipeline(
+ "image-classification", model="google/vit-base-patch16-224"
+ )
+
+ def classify(self, image: ImageFile.ImageFile) -> Dict[str, float]:
+ results = self.model(image)
+ return {pred["label"]: pred["score"] for pred in results}
+ # __doc_classifier_end__
+
+
+# __doc_translator_begin__
+@serve.deployment
+class Translator:
+ def __init__(self):
+ self.model = pipeline("translation_en_to_de", model="t5-small")
+
+ def translate(self, dict: Dict[str, float]) -> Dict[str, float]:
+ results = {}
+ for label, score in dict.items():
+ translated_label = self.model(label)[0]["translation_text"]
+ results[translated_label] = score
+
+ return results
+ # __doc_translator_end__
+
+
+# __doc_build_graph_begin__
+with InputNode(input_type=str) as user_input:
+ classifier = ImageClassifier.bind()
+ translator = Translator.bind()
+
+ downloaded_image = downloader.bind(user_input)
+ classes = classifier.classify.bind(downloaded_image)
+ translated_classes = translator.translate.bind(classes)
+
+ serve_entrypoint = DAGDriver.bind(translated_classes)
+ # __doc_build_graph_end__
diff --git a/doc/source/serve/model_composition.md b/doc/source/serve/model_composition.md
index d77c01e362b28..d07dcc68211ee 100644
--- a/doc/source/serve/model_composition.md
+++ b/doc/source/serve/model_composition.md
@@ -324,6 +324,7 @@ $ python arithmetic.py
9
```
+(pydot-visualize-dag)=
### Visualizing the Graph
You can render an illustration of your deployment graph to see its nodes and their connection.
@@ -393,6 +394,9 @@ On the other hand, when the script visualizes the final graph output, `combine_o

+#### Visualizing the Graph with Gradio
+Another option is to visualize your deployment graph through Gradio. Check out the [Graph Visualization with Gradio Tutorial](serve-gradio-dag-visualization) to learn how to interactively run your deployment graph through the Gradio UI and see the intermediate outputs of each node in real time as they finish evaluation.
+
## Next Steps
To learn more about deployment graphs, check out some [deployment graph patterns](serve-deployment-graph-patterns-overview) you can incorporate into your own graph!
diff --git a/doc/source/serve/performance.md b/doc/source/serve/performance.md
index b11107563e9ce..62aed1f8b1cb0 100644
--- a/doc/source/serve/performance.md
+++ b/doc/source/serve/performance.md
@@ -144,8 +144,10 @@ There are handful of ways to address these issues:
* Are you reserving GPUs for your deployment replicas using `ray_actor_options` (e.g. `ray_actor_options={“num_gpus”: 1}`)?
* Are you reserving one or more cores for your deployment replicas using `ray_actor_options` (e.g. `ray_actor_options={“num_cpus”: 2}`)?
* Are you setting [OMP_NUM_THREADS](serve-omp-num-threads) to increase the performance of your deep learning framework?
-2. Consider using `async` methods in your callable. See [the section below](serve-performance-async-methods).
-3. Consider batching your requests. See [the section below](serve-performance-batching-requests).
+2. Try batching your requests. See [the section above](serve-performance-batching-requests).
+3. Consider using `async` methods in your callable. See [the section below](serve-performance-async-methods).
+4. Set an end-to-end timeout for your HTTP requests. See [the section below](serve-performance-e2e-timeout).
+
(serve-performance-async-methods)=
### Using `async` methods
@@ -159,3 +161,10 @@ hitting the same queuing issue mentioned above, you might want to increase
`max_concurrent_queries`. Serve sets a low number (100) by default so the client gets
proper backpressure. You can increase the value in the deployment decorator; e.g.
`@serve.deployment(max_concurrent_queries=1000)`.
+
+(serve-performance-e2e-timeout)=
+### Set an end-to-end request timeout
+
+By default, Serve lets client HTTP requests run to completion no matter how long they take. However, slow requests could bottleneck the replica processing, blocking other requests that are waiting. It's recommended that you set an end-to-end timeout, so slow requests can be terminated and retried at another replica.
+
+You can set an end-to-end timeout for HTTP requests by setting the `RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S` environment variable. HTTP Proxies will wait for that many seconds before terminating an HTTP request and retrying it at another replica. This environment variable should be set on every node in your Ray cluster, and it cannot be updated during runtime.
diff --git a/doc/source/serve/tutorials/gradio-dag-visualization.md b/doc/source/serve/tutorials/gradio-dag-visualization.md
new file mode 100644
index 0000000000000..5e34ad62ddb7a
--- /dev/null
+++ b/doc/source/serve/tutorials/gradio-dag-visualization.md
@@ -0,0 +1,206 @@
+(serve-gradio-dag-visualization)=
+# Visualizing a Deployment Graph with Gradio
+
+You can visualize the [deployment graph](serve-model-composition-deployment-graph) you built with [Gradio](https://gradio.app/). This integration allows you to interactively run your deployment graph through the Gradio UI and see the intermediate outputs of each node in real time as they finish evaluation.
+
+To access this feature, you need to install Gradio.
+:::{note}
+Gradio requires Python 3.7+. Make sure to install Python 3.7+ to use this tool.
+:::
+```console
+pip install gradio
+```
+
+Additionally, you can optionally install `pydot` and `graphviz`. This will allow this tool to incorporate the complementary [graphical illustration](pydot-visualize-dag) of the nodes and edges.
+::::{tabbed} MacOS
+```
+pip install -U pydot && brew install graphviz
+```
+::::
+
+::::{tabbed} Windows
+```
+pip install -U pydot && winget install graphviz
+```
+::::
+
+::::{tabbed} Linux
+```
+pip install -U pydot && sudo apt-get install -y graphviz
+```
+::::
+
+
+Also, for the [quickstart example](gradio-vis-quickstart), install the `transformers` module to pull models through [HuggingFace's Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines).
+```console
+pip install transformers
+```
+
+(gradio-vis-quickstart)=
+## Quickstart Example
+
+Let's build and visualize a deployment graph that
+ 1. Downloads an image
+ 2. Classifies the image
+ 3. Translates the results to German.
+
+This will be the graphical structure of our deployment graph:
+
+
+
+Open up a new file named `demo.py`. First, let's take care of imports:
+```{literalinclude} ../doc_code/gradio_dag_visualize.py
+:start-after: __doc_import_begin__
+:end-before: __doc_import_end__
+:language: python
+```
+
+### Defining Nodes
+
+The `downloader` function takes an image's URL, downloads it, and returns the image in the form of an `ImageFile`.
+```{literalinclude} ../doc_code/gradio_dag_visualize.py
+:start-after: __doc_downloader_begin__
+:end-before: __doc_downloader_end__
+:language: python
+```
+
+The `ImageClassifier` class, upon initialization, loads the `google/vit-base-patch16-224` image classification model using the Transformers pipeline. Its `classify` method takes in an `ImageFile`, runs the model on it, and outputs the classification labels and scores.
+```{literalinclude} ../doc_code/gradio_dag_visualize.py
+:start-after: __doc_classifier_begin__
+:end-before: __doc_classifier_end__
+:language: python
+```
+
+The `Translator` class, upon initialization, loads the `t5-small` translation model that translates from English to German. Its `translate` method takes in a map from strings to floats, and translates each of its string keys to German.
+```{literalinclude} ../doc_code/gradio_dag_visualize.py
+:start-after: __doc_translator_begin__
+:end-before: __doc_translator_end__
+:language: python
+```
+
+### Building the Graph
+
+Finally, we can build our graph by defining dependencies between nodes.
+```{literalinclude} ../doc_code/gradio_dag_visualize.py
+:start-after: __doc_build_graph_begin__
+:end-before: __doc_build_graph_end__
+:language: python
+```
+
+### Deploy and Execute
+
+Let's deploy and run the deployment graph! Deploy the graph with `serve run` and turn on the visualization with the `gradio` flag:
+```console
+serve run demo:serve_entrypoint --gradio
+```
+
+If you go to `http://localhost:7860`, you can now access the Gradio visualization! Type in a link to an image, click "Run", and you can see all of the intermediate outputs of your graph, including the final output!
+
+
+## Setting Up the Visualization
+Now let's see how to set up this visualization tool.
+
+### Requirement: Driver
+
+The `DAGDriver` is required for the visualization. If the `DAGDriver` is not already part of your deployment graph, you can include it with:
+```python
+new_root_node = DAGDriver.bind(old_root_node)
+```
+
+### Ensure Output Data is Properly Displayed
+
+Since the Gradio UI is set at deploy time, the type of Gradio component used to display intermediate outputs of the graph is also statically determined from the graph deployed. It is important that the correct Gradio component is used for each graph node.
+
+The developer simply needs to specify the return type annotation of each function or method in the deployment graph.
+:::{note}
+If no return type annotation is specified for a node, then the Gradio component for that node will default to a [Gradio Textbox](https://gradio.app/docs/#textbox).
+:::
+
+The following table lists the supported data types and which Gradio component they're displayed on.
+
+:::{list-table}
+:widths: 60 40
+:header-rows: 1
+* - Data Type
+ - Gradio component
+* - `int`, `float`
+ - [Numeric field](https://gradio.app/docs/#number)
+* - `str`
+ - [Textbox](https://gradio.app/docs/#textbox)
+* - `bool`
+ - [Checkbox](https://gradio.app/docs/#checkbox)
+* - `pd.Dataframe`
+ - [DataFrame](https://gradio.app/docs/#dataframe)
+* - `list`, `dict`, `np.ndarray`
+ - [JSON field](https://gradio.app/docs/#json)
+* - `PIL.Image`, `torch.Tensor`
+ - [Image](https://gradio.app/docs/#image)
+:::
+
+For instance, the output of the following function node will be displayed through a [Gradio Checkbox](https://gradio.app/docs/#textbox).
+```python
+@serve.deployment
+def is_valid(begin, end) -> bool:
+ return begin <= end
+```
+
+
+
+
+### Providing Input
+
+Similarly, the Gradio component used for each graph input should also be correct. For instance, a deployment graph for image classification could either take an image URL from which it downloads the image, or take the image directly as input. In the first case, the Gradio UI should allow users to input the URL through a textbox, but in the second case, the Gradio UI should allow users to upload the image through an Image component.
+
+The data type of each user input can be specified by passing in `input_type` to `InputNode()`. The following two sections will describe the two supported ways to provide input through `input_type`.
+
+The following table describes the supported input data types and which Gradio component is used to collect that data.
+
+:::{list-table}
+:widths: 60 40
+:header-rows: 1
+* - Data Type
+ - Gradio component
+* - `int`, `float`
+ - [Numeric field](https://gradio.app/docs/#number)
+* - `str`
+ - [Textbox](https://gradio.app/docs/#textbox)
+* - `bool`
+ - [Checkbox](https://gradio.app/docs/#checkbox)
+* - `pd.Dataframe`
+ - [DataFrame](https://gradio.app/docs/#dataframe)
+* - `PIL.Image`, `torch.Tensor`
+ - [Image](https://gradio.app/docs/#image)
+:::
+
+#### Single Input
+
+If there is a single input to the deployment graph, it can be provided directly through `InputNode`. The following is an example code snippet.
+
+```python
+with InputNode(input_type=ImageFile) as user_input:
+ f_node = f.bind(user_input)
+```
+
+:::{note}
+Notice there is a single input, which is stored in `user_input` (an instance of `InputNode`). The data type of this single input must be one of the supported input data types.
+:::
+When initializating `InputNode()`, the data type can be specified by passing in a `type` variable to the parameter `input_type`. Here, the type is specified to be `ImageFile`, so the Gradio visualization will take in user input through an [Image component](https://gradio.app/docs/#image).
+
+
+#### Multiple Inputs
+
+If there are multiple inputs to the deployment graph, it can be provided by accessing attributes of `InputNode`. The following is an example code snippet.
+
+```python
+with InputNode(input_type={0: int, 1: str, "id": str}) as user_input:
+ f_node = f.bind(user_input[0])
+ g_node = g.bind(user_input[1], user_input["id"])
+```
+
+:::{note}
+Notice there are multiple inputs: `user_input[0]`, `user_input[1]`, and `user_input["id"]`. They are accessed by indexing into `user_input`. The data types for each of these inputs must be one of the supported input data types.
+:::
+
+When initializing `InputNode()`, these data types can be specified by passing in a dictionary that maps key to `type` (where key is integer or string) to the parameter `input_type`. Here, the input types are specified to be `int`, `str`, and `str`, so the Gradio visualization will take in the three inputs through one [Numeric Field](https://gradio.app/docs/#number) and two [Textboxes](https://gradio.app/docs/#textbox).
+
+
\ No newline at end of file
diff --git a/doc/source/serve/tutorials/index.md b/doc/source/serve/tutorials/index.md
index 6986b30788163..038b2b2ede1ea 100644
--- a/doc/source/serve/tutorials/index.md
+++ b/doc/source/serve/tutorials/index.md
@@ -13,5 +13,6 @@ serve-ml-models
batch
rllib
gradio-integration
+gradio-dag-visualization
java
```
diff --git a/python/ray/_private/gcs_utils.py b/python/ray/_private/gcs_utils.py
index 55558332578f2..71cbc14dddbd9 100644
--- a/python/ray/_private/gcs_utils.py
+++ b/python/ray/_private/gcs_utils.py
@@ -577,3 +577,40 @@ def use_gcs_for_bootstrap():
This function is included for the purposes of backwards compatibility.
"""
return True
+
+
+def cleanup_redis_storage(
+ host: str, port: int, password: str, use_ssl: bool, storage_namespace: str
+):
+ """This function is used to cleanup the storage. Before we having
+ a good design for storage backend, it can be used to delete the old
+ data. It support redis cluster and non cluster mode.
+
+ Args:
+ host: The host address of the Redis.
+ port: The port of the Redis.
+ password: The password of the Redis.
+ use_ssl: Whether to encrypt the connection.
+ storage_namespace: The namespace of the storage to be deleted.
+ """
+
+ from ray._raylet import del_key_from_storage # type: ignore
+
+ if not isinstance(host, str):
+ raise ValueError("Host must be a string")
+
+ if not isinstance(password, str):
+ raise ValueError("Password must be a string")
+
+ if port < 0:
+ raise ValueError(f"Invalid port: {port}")
+
+ if not isinstance(use_ssl, bool):
+ raise TypeError("use_ssl must be a boolean")
+
+ if not isinstance(storage_namespace, str):
+ raise ValueError("storage namespace must be a string")
+
+ # Right now, GCS store all data into a hash set key by storage_namespace.
+ # So we only need to delete the specific key to cleanup the cluster.
+ return del_key_from_storage(host, port, password, use_ssl, storage_namespace)
diff --git a/python/ray/_private/import_thread.py b/python/ray/_private/import_thread.py
index 271aace307983..1da268c39e24e 100644
--- a/python/ray/_private/import_thread.py
+++ b/python/ray/_private/import_thread.py
@@ -7,6 +7,7 @@
import ray
import ray._private.profiling as profiling
+from ray import JobID
from ray import cloudpickle as pickle
from ray._private import ray_constants
@@ -42,26 +43,33 @@ def __init__(self, worker, mode, threads_stopped):
self.num_imported = 0
# Protect writes to self.num_imported.
self._lock = threading.Lock()
+ # Protect start and join of import thread.
+ self._thread_spawn_lock = threading.Lock()
# Try to load all FunctionsToRun so that these functions will be
# run before accepting tasks.
self._do_importing()
def start(self):
"""Start the import thread."""
- self.t = threading.Thread(target=self._run, name="ray_import_thread")
- # Making the thread a daemon causes it to exit
- # when the main thread exits.
- self.t.daemon = True
- self.t.start()
+ with self._thread_spawn_lock:
+ if self.t is not None:
+ return
+ self.t = threading.Thread(target=self._run, name="ray_import_thread")
+ # Making the thread a daemon causes it to exit
+ # when the main thread exits.
+ self.t.daemon = True
+ self.t.start()
def join_import_thread(self):
"""Wait for the thread to exit."""
- if self.t:
- self.t.join()
+ with self._thread_spawn_lock:
+ if self.t:
+ self.t.join()
def _run(self):
try:
- self._do_importing()
+ if not self.threads_stopped.is_set():
+ self._do_importing()
while True:
# Exit if we received a signal that we should stop.
if self.threads_stopped.is_set():
@@ -79,10 +87,14 @@ def _run(self):
self.subscriber.close()
def _do_importing(self):
+ job_id = self.worker.current_job_id
+ if job_id == JobID.nil():
+ return
+
while True:
with self._lock:
export_key = ray._private.function_manager.make_export_key(
- self.num_imported + 1, self.worker.current_job_id
+ self.num_imported + 1, job_id
)
key = self.gcs_client.internal_kv_get(
export_key, ray_constants.KV_NAMESPACE_FUNCTION_TABLE
diff --git a/python/ray/_private/runtime_env/_clonevirtualenv.py b/python/ray/_private/runtime_env/_clonevirtualenv.py
index 269a0e4487813..1f2eab3d10403 100644
--- a/python/ray/_private/runtime_env/_clonevirtualenv.py
+++ b/python/ray/_private/runtime_env/_clonevirtualenv.py
@@ -15,7 +15,7 @@
__version__ = "0.5.7"
-logger = logging.getLogger(__name__)
+logger = logging.getLogger()
env_bin_dir = "bin"
diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py
index 2e9a8e90a44fb..e85fc528d52c8 100644
--- a/python/ray/_private/services.py
+++ b/python/ray/_private/services.py
@@ -1129,6 +1129,7 @@ def start_api_server(
# TODO(sang): Modules like job or state APIs should be
# loaded although dashboard is disabled. Fix it.
command.append("--modules-to-load=UsageStatsHead")
+ command.append("--disable-frontend")
process_info = start_ray_process(
command,
@@ -1225,7 +1226,7 @@ def read_log(filename, lines_to_read):
# Is it reachable?
raise Exception("Failed to start a dashboard.")
- if minimal:
+ if minimal or not include_dashboard:
# If it is the minimal installation, the web url (dashboard url)
# shouldn't be configured because it doesn't start a server.
dashboard_url = ""
diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py
index be924ca9241d6..d784c22fc77d8 100644
--- a/python/ray/_private/worker.py
+++ b/python/ray/_private/worker.py
@@ -2097,12 +2097,19 @@ def connect(
" and will be removed in the future."
)
- # Start the import thread
+ # Setup import thread and start the import thread
+ # if the worker has job_id initialized.
+ # Otherwise, defer the start up of
+ # import thread until job_id is initialized.
+ # (python/ray/_raylet.pyx maybe_initialize_job_config)
if mode not in (RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
worker.import_thread = import_thread.ImportThread(
worker, mode, worker.threads_stopped
)
- if ray._raylet.Config.start_python_importer_thread():
+ if (
+ worker.current_job_id != JobID.nil()
+ and ray._raylet.Config.start_python_importer_thread()
+ ):
worker.import_thread.start()
# If this is a driver running in SCRIPT_MODE, start a thread to print error
@@ -2194,6 +2201,19 @@ def disconnect(exiting_interpreter=False):
ray_actor._ActorClassMethodMetadata.reset_cache()
+def start_import_thread():
+ """Start the import thread if the worker is connected."""
+ worker = global_worker
+ worker.check_connected()
+
+ assert _mode() not in (
+ RESTORE_WORKER_MODE,
+ SPILL_WORKER_MODE,
+ ), "import thread can not be used in IO workers."
+ if worker.import_thread and ray._raylet.Config.start_python_importer_thread():
+ worker.import_thread.start()
+
+
@contextmanager
def _changeproctitle(title, next_title):
if _mode() is not LOCAL_MODE:
diff --git a/python/ray/_raylet.pyi b/python/ray/_raylet.pyi
index 691620b277176..0c5c78362557f 100644
--- a/python/ray/_raylet.pyi
+++ b/python/ray/_raylet.pyi
@@ -3,9 +3,9 @@ from typing import Any, Awaitable, TypeVar
R = TypeVar("R")
-class ObjectRef(Awaitable[R]):
+class ObjectRef(Awaitable[R]): # type: ignore
pass
-class ObjectID(Awaitable[R]):
+class ObjectID(Awaitable[R]): # type: ignore
pass
diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx
index 1e2dccb927bb2..15e25be41a291 100644
--- a/python/ray/_raylet.pyx
+++ b/python/ray/_raylet.pyx
@@ -109,7 +109,7 @@ from ray.includes.libcoreworker cimport (
from ray.includes.ray_config cimport RayConfig
from ray.includes.global_state_accessor cimport CGlobalStateAccessor
-
+from ray.includes.global_state_accessor cimport RedisDelKeySync
from ray.includes.optional cimport (
optional
)
@@ -166,6 +166,7 @@ current_task_id = None
current_task_id_lock = threading.Lock()
job_config_initialized = False
+job_config_initialization_lock = threading.Lock()
class ObjectRefGenerator:
@@ -745,8 +746,8 @@ cdef void execute_task(
function_descriptor = CFunctionDescriptorToPython(
ray_function.GetFunctionDescriptor())
function_name = execution_info.function_name
- extra_data = (b'{"name": ' + function_name.encode("ascii") +
- b' "task_id": ' + task_id.hex().encode("ascii") + b'}')
+ extra_data = (b'{"name": "' + function_name.encode("ascii") +
+ b'", "task_id": "' + task_id.hex().encode("ascii") + b'"}')
name_of_concurrency_group_to_execute = \
c_name_of_concurrency_group_to_execute.decode("ascii")
@@ -1389,35 +1390,40 @@ cdef void unhandled_exception_handler(const CRayObject& error) nogil:
def maybe_initialize_job_config():
- global job_config_initialized
- if job_config_initialized:
- return
- # Add code search path to sys.path, set load_code_from_local.
- core_worker = ray._private.worker.global_worker.core_worker
- code_search_path = core_worker.get_job_config().code_search_path
- load_code_from_local = False
- if code_search_path:
- load_code_from_local = True
- for p in code_search_path:
- if os.path.isfile(p):
- p = os.path.dirname(p)
- sys.path.insert(0, p)
- ray._private.worker.global_worker.set_load_code_from_local(load_code_from_local)
-
- # Add driver's system path to sys.path
- py_driver_sys_path = core_worker.get_job_config().py_driver_sys_path
- if py_driver_sys_path:
- for p in py_driver_sys_path:
- sys.path.insert(0, p)
- job_config_initialized = True
-
- # Record the task name via :task_name: magic token in the log file.
- # This is used for the prefix in driver logs `(task_name pid=123) ...`
- job_id_magic_token = "{}{}\n".format(
- ray_constants.LOG_PREFIX_JOB_ID, core_worker.get_current_job_id().hex())
- # Print on both .out and .err
- print(job_id_magic_token, end="")
- print(job_id_magic_token, file=sys.stderr, end="")
+ with job_config_initialization_lock:
+ global job_config_initialized
+ if job_config_initialized:
+ return
+ # Add code search path to sys.path, set load_code_from_local.
+ core_worker = ray._private.worker.global_worker.core_worker
+ code_search_path = core_worker.get_job_config().code_search_path
+ load_code_from_local = False
+ if code_search_path:
+ load_code_from_local = True
+ for p in code_search_path:
+ if os.path.isfile(p):
+ p = os.path.dirname(p)
+ sys.path.insert(0, p)
+ ray._private.worker.global_worker.set_load_code_from_local(load_code_from_local)
+
+ # Add driver's system path to sys.path
+ py_driver_sys_path = core_worker.get_job_config().py_driver_sys_path
+ if py_driver_sys_path:
+ for p in py_driver_sys_path:
+ sys.path.insert(0, p)
+
+ # Record the task name via :task_name: magic token in the log file.
+ # This is used for the prefix in driver logs `(task_name pid=123) ...`
+ job_id_magic_token = "{}{}\n".format(
+ ray_constants.LOG_PREFIX_JOB_ID, core_worker.get_current_job_id().hex())
+ # Print on both .out and .err
+ print(job_id_magic_token, end="")
+ print(job_id_magic_token, file=sys.stderr, end="")
+
+ # Only start import thread after job_config is initialized
+ ray._private.worker.start_import_thread()
+
+ job_config_initialized = True
# This function introduces ~2-7us of overhead per call (i.e., it can be called
@@ -2787,3 +2793,7 @@ cdef void async_callback(shared_ptr[CRayObject] obj,
py_callback = user_callback
py_callback(result)
cpython.Py_DECREF(py_callback)
+
+
+def del_key_from_storage(host, port, password, use_ssl, key):
+ return RedisDelKeySync(host, port, password, use_ssl, key)
diff --git a/python/ray/air/checkpoint.py b/python/ray/air/checkpoint.py
index 2f0e82cbf7d83..07a2ee7734c13 100644
--- a/python/ray/air/checkpoint.py
+++ b/python/ray/air/checkpoint.py
@@ -116,7 +116,7 @@ class Checkpoint:
When converting between different checkpoint formats, it is guaranteed
that a full round trip of conversions (e.g. directory --> dict -->
- obj ref --> directory) will recover the original checkpoint data.
+ --> directory) will recover the original checkpoint data.
There are no guarantees made about compatibility of intermediate
representations.
@@ -142,10 +142,7 @@ class Checkpoint:
same node or a node that also has access to the local data path (e.g.
on a shared file system like NFS).
- Checkpoints pointing to object store references will keep the
- object reference in tact - this means that these checkpoints cannot
- be properly deserialized on other Ray clusters or outside a Ray
- cluster. If you need persistence across clusters, use the ``to_uri()``
+ If you need persistence across clusters, use the ``to_uri()``
or ``to_directory()`` methods to persist your checkpoints to disk.
"""
@@ -165,7 +162,6 @@ def __init__(
local_path: Optional[Union[str, os.PathLike]] = None,
data_dict: Optional[dict] = None,
uri: Optional[str] = None,
- obj_ref: Optional[ray.ObjectRef] = None,
):
# First, resolve file:// URIs to local paths
if uri:
@@ -175,7 +171,7 @@ def __init__(
# Only one data type can be set at any time
if local_path:
- assert not data_dict and not uri and not obj_ref
+ assert not data_dict and not uri
if not isinstance(local_path, (str, os.PathLike)) or not os.path.exists(
local_path
):
@@ -191,21 +187,14 @@ def __init__(
f"instead."
)
elif data_dict:
- assert not local_path and not uri and not obj_ref
+ assert not local_path and not uri
if not isinstance(data_dict, dict):
raise RuntimeError(
f"Cannot create checkpoint from dict as no "
f"dict was passed: {data_dict}"
)
- elif obj_ref:
- assert not local_path and not data_dict and not uri
- if not isinstance(obj_ref, ray.ObjectRef):
- raise RuntimeError(
- f"Cannot create checkpoint from object ref as no "
- f"object ref was passed: {obj_ref}"
- )
elif uri:
- assert not local_path and not data_dict and not obj_ref
+ assert not local_path and not data_dict
resolved = _get_external_path(uri)
if not resolved:
raise RuntimeError(
@@ -221,7 +210,6 @@ def __init__(
)
self._data_dict: Optional[Dict[str, Any]] = data_dict
self._uri: Optional[str] = uri
- self._obj_ref: Optional[ray.ObjectRef] = obj_ref
self._override_preprocessor: Optional["Preprocessor"] = None
self._uuid = uuid.uuid4()
@@ -349,9 +337,6 @@ def to_dict(self) -> dict:
if self._data_dict:
# If the checkpoint data is already a dict, return
checkpoint_data = self._data_dict
- elif self._obj_ref:
- # If the checkpoint data is an object reference, resolve
- checkpoint_data = ray.get(self._obj_ref)
elif self._local_path or self._uri:
# Else, checkpoint is either on FS or external storage
with self.as_directory() as local_path:
@@ -462,7 +447,6 @@ def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint":
local_path=other._local_path,
data_dict=other._data_dict,
uri=other._uri,
- obj_ref=other._obj_ref,
)
new_checkpoint._copy_metadata_attrs_from(other)
return new_checkpoint
@@ -497,8 +481,7 @@ def _save_checkpoint_metadata_in_directory(self, path: str) -> None:
pickle.dump(self._metadata, file)
def _to_directory(self, path: str, move_instead_of_copy: bool = False) -> None:
- if self._data_dict or self._obj_ref:
- # This is a object ref or dict
+ if self._data_dict:
data_dict = self.to_dict()
if _FS_CHECKPOINT_KEY in data_dict:
for key in data_dict.keys():
@@ -741,7 +724,7 @@ def get_internal_representation(
objects for equality or to access the underlying data storage.
The returned type is a string and one of
- ``["local_path", "data_dict", "uri", "object_ref"]``.
+ ``["local_path", "data_dict", "uri"]``.
The data is the respective data value.
@@ -757,8 +740,6 @@ def get_internal_representation(
return "data_dict", self._data_dict
elif self._uri:
return "uri", self._uri
- elif self._obj_ref:
- return "object_ref", self._obj_ref
else:
raise RuntimeError(
"Cannot get internal representation of empty checkpoint."
diff --git a/python/ray/air/integrations/wandb.py b/python/ray/air/integrations/wandb.py
index 54e9571463e27..bd52e3c3ff85f 100644
--- a/python/ray/air/integrations/wandb.py
+++ b/python/ray/air/integrations/wandb.py
@@ -355,8 +355,8 @@ class _QueueItem(enum.Enum):
class _WandbLoggingActor:
"""
- We need a separate process to allow multiple concurrent
- wandb logging instances locally. We use Ray actors as forking multiprocessing
+ Wandb assumes that each trial's information should be logged from a
+ separate process. We use Ray actors as forking multiprocessing
processes is not supported by Ray and spawn processes run into pickling
problems.
diff --git a/python/ray/air/tests/mocked_wandb_integration.py b/python/ray/air/tests/mocked_wandb_integration.py
new file mode 100644
index 0000000000000..92e3ef1d27447
--- /dev/null
+++ b/python/ray/air/tests/mocked_wandb_integration.py
@@ -0,0 +1,117 @@
+from collections import namedtuple
+from queue import Queue
+import threading
+from unittest.mock import Mock
+from wandb.util import json_dumps_safer
+
+from ray.air.integrations.wandb import (
+ _WandbLoggingActor,
+ _QueueItem,
+ WandbLoggerCallback,
+)
+
+
+class Trial(
+ namedtuple(
+ "MockTrial",
+ [
+ "config",
+ "trial_id",
+ "trial_name",
+ "experiment_dir_name",
+ "placement_group_factory",
+ "logdir",
+ ],
+ )
+):
+ def __hash__(self):
+ return hash(self.trial_id)
+
+ def __str__(self):
+ return self.trial_name
+
+
+class _FakeConfig:
+ """Thread-safe."""
+
+ def __init__(self):
+ self.queue = Queue()
+
+ # This is called during both on_trial_start and on_trial_result.
+ def update(self, config, *args, **kwargs):
+ self.queue.put(config)
+
+
+class _MockWandbAPI:
+ """Thread-safe.
+
+ Note: Not implemented to mock re-init behavior properly. Proceed with caution."""
+
+ def __init__(self):
+ self.logs = Queue()
+ self.config = _FakeConfig()
+
+ def init(self, *args, **kwargs):
+ mock = Mock()
+ mock.args = args
+ mock.kwargs = kwargs
+
+ if "config" in kwargs:
+ self.config.update(kwargs["config"])
+
+ return mock
+
+ def log(self, data):
+ try:
+ json_dumps_safer(data)
+ except Exception:
+ self.logs.put("serialization error")
+ else:
+ self.logs.put(data)
+
+ def finish(self):
+ pass
+
+
+class _MockWandbLoggingActor(_WandbLoggingActor):
+ def __init__(self, logdir, queue, exclude, to_config, *args, **kwargs):
+ super(_MockWandbLoggingActor, self).__init__(
+ logdir, queue, exclude, to_config, *args, **kwargs
+ )
+ self._wandb = _MockWandbAPI()
+
+
+class WandbTestExperimentLogger(WandbLoggerCallback):
+
+ """Wandb logger with mocked Wandb API gateway (one per trial)."""
+
+ @property
+ def trial_processes(self):
+ return self._trial_logging_actors
+
+ def _start_logging_actor(self, trial, exclude_results, **wandb_init_kwargs):
+ self._trial_queues[trial] = Queue()
+ local_actor = _MockWandbLoggingActor(
+ logdir=trial.logdir,
+ queue=self._trial_queues[trial],
+ exclude=exclude_results,
+ to_config=self.AUTO_CONFIG_KEYS,
+ **wandb_init_kwargs,
+ )
+ self._trial_logging_actors[trial] = local_actor
+
+ thread = threading.Thread(target=local_actor.run)
+ self._trial_logging_futures[trial] = thread
+ thread.start()
+
+ def _stop_logging_actor(self, trial: "Trial", timeout: int = 10):
+ # Unique for the mocked instance is the delayed delete of
+ # `self._trial_logging_actors`.
+ # This is because we want to access them in unit test after `.fit()`
+ # to assert certain config and log is called with wandb.
+ if trial in self._trial_queues:
+ self._trial_queues[trial].put((_QueueItem.END, None))
+ del self._trial_queues[trial]
+ if trial in self._trial_logging_futures:
+ self._trial_logging_futures[trial].join(timeout=2)
+ del self._trial_logging_futures[trial]
diff --git a/python/ray/air/tests/test_integration_wandb.py b/python/ray/air/tests/test_integration_wandb.py
index a8888b48106b4..e7f43552f5436 100644
--- a/python/ray/air/tests/test_integration_wandb.py
+++ b/python/ray/air/tests/test_integration_wandb.py
@@ -1,10 +1,6 @@
import os
import tempfile
-import threading
-from collections import namedtuple
-from dataclasses import dataclass
-from queue import Queue
-from typing import Tuple, Dict
+
from unittest.mock import (
Mock,
patch,
@@ -14,12 +10,12 @@
import pytest
import ray
+
from ray.tune import Trainable
+from ray.tune.integration.wandb import WandbTrainableMixin
+
from ray.tune.trainable import wrap_function
-from ray.tune.integration.wandb import (
- WandbTrainableMixin,
- wandb_mixin,
-)
+from ray.tune.integration.wandb import wandb_mixin
from ray.air.integrations.wandb import (
WandbLoggerCallback,
_QueueItem,
@@ -35,106 +31,13 @@
from ray.tune.result import TRIAL_INFO
from ray.tune.experiment.trial import _TrialInfo
from ray.tune.execution.placement_groups import PlacementGroupFactory
-from wandb.util import json_dumps_safer
-
-
-class Trial(
- namedtuple(
- "MockTrial",
- [
- "config",
- "trial_id",
- "trial_name",
- "experiment_dir_name",
- "placement_group_factory",
- "logdir",
- ],
- )
-):
- def __hash__(self):
- return hash(self.trial_id)
-
- def __str__(self):
- return self.trial_name
-
-
-@dataclass
-class _MockWandbConfig:
- args: Tuple
- kwargs: Dict
-
-
-class _FakeConfig:
- def update(self, config, *args, **kwargs):
- for key, value in config.items():
- setattr(self, key, value)
-
- def __iter__(self):
- return iter(self.__dict__)
-
-
-class _MockWandbAPI:
- def __init__(self):
- self.logs = Queue()
- self.config = _FakeConfig()
-
- def init(self, *args, **kwargs):
- mock = Mock()
- mock.args = args
- mock.kwargs = kwargs
- if "config" in kwargs:
- self.config.update(kwargs["config"])
-
- return mock
-
- def log(self, data):
- try:
- json_dumps_safer(data)
- except Exception:
- self.logs.put("serialization error")
- else:
- self.logs.put(data)
-
- def finish(self):
- pass
-
-
-class _MockWandbLoggingActor(_WandbLoggingActor):
- def __init__(self, logdir, queue, exclude, to_config, *args, **kwargs):
- super(_MockWandbLoggingActor, self).__init__(
- logdir, queue, exclude, to_config, *args, **kwargs
- )
- self._wandb = _MockWandbAPI()
-
-
-class WandbTestExperimentLogger(WandbLoggerCallback):
- @property
- def trial_processes(self):
- return self._trial_logging_actors
-
- def _start_logging_actor(self, trial, exclude_results, **wandb_init_kwargs):
- self._trial_queues[trial] = Queue()
- local_actor = _MockWandbLoggingActor(
- logdir=trial.logdir,
- queue=self._trial_queues[trial],
- exclude=exclude_results,
- to_config=self.AUTO_CONFIG_KEYS,
- **wandb_init_kwargs,
- )
- self._trial_logging_actors[trial] = local_actor
-
- thread = threading.Thread(target=local_actor.run)
- self._trial_logging_futures[trial] = thread
- thread.start()
-
- def _stop_logging_actor(self, trial: "Trial", timeout: int = 10):
- self._trial_queues[trial].put((_QueueItem.END, None))
-
- del self._trial_queues[trial]
- del self._trial_logging_actors[trial]
- self._trial_logging_futures[trial].join(timeout=2)
- del self._trial_logging_futures[trial]
+from ray.air.tests.mocked_wandb_integration import (
+ _MockWandbAPI,
+ _MockWandbLoggingActor,
+ Trial,
+ WandbTestExperimentLogger,
+)
class _MockWandbTrainableMixin(WandbTrainableMixin):
@@ -372,10 +275,14 @@ def test_wandb_logger_reporting(self, trial):
def test_wandb_logger_auto_config_keys(self, trial):
logger = WandbTestExperimentLogger(project="test_project", api_key="1234")
logger.on_trial_start(iteration=0, trials=[], trial=trial)
- config = logger.trial_processes[trial]._wandb.config
+ config = logger.trial_processes[trial]._wandb.config.queue.get(timeout=10)
result = {key: 0 for key in WandbLoggerCallback.AUTO_CONFIG_KEYS}
logger.on_trial_result(0, [], trial, result)
+ config_increment = logger.trial_processes[trial]._wandb.config.queue.get(
+ timeout=10
+ )
+ config.update(config_increment)
logger.on_trial_complete(0, [], trial)
# The results in `AUTO_CONFIG_KEYS` should be saved as training configuration
@@ -397,11 +304,15 @@ def test_wandb_logger_exclude_config(self):
excludes=(["param2"] + WandbLoggerCallback.AUTO_CONFIG_KEYS),
)
logger.on_trial_start(iteration=0, trials=[], trial=trial)
- config = logger.trial_processes[trial]._wandb.config
+ config = logger.trial_processes[trial]._wandb.config.queue.get(timeout=10)
# We need to test that `excludes` also applies to `AUTO_CONFIG_KEYS`.
result = {key: 0 for key in WandbLoggerCallback.AUTO_CONFIG_KEYS}
logger.on_trial_result(0, [], trial, result)
+ config_increment = logger.trial_processes[trial]._wandb.config.queue.get(
+ timeout=10
+ )
+ config.update(config_increment)
logger.on_trial_complete(0, [], trial)
assert set(config) == {"param1"}
diff --git a/python/ray/autoscaler/_private/autoscaler.py b/python/ray/autoscaler/_private/autoscaler.py
index b133924ea922e..ead10d8419502 100644
--- a/python/ray/autoscaler/_private/autoscaler.py
+++ b/python/ray/autoscaler/_private/autoscaler.py
@@ -20,6 +20,7 @@
AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
AUTOSCALER_MAX_LAUNCH_BATCH,
AUTOSCALER_MAX_NUM_FAILURES,
+ AUTOSCALER_STATUS_LOG,
AUTOSCALER_UPDATE_INTERVAL_S,
DISABLE_LAUNCH_CONFIG_CHECK_KEY,
DISABLE_NODE_UPDATERS_KEY,
@@ -414,7 +415,8 @@ def _update(self):
)
# Update status strings
- logger.info(self.info_string())
+ if AUTOSCALER_STATUS_LOG:
+ logger.info(self.info_string())
legacy_log_info_string(self, self.non_terminated_nodes.worker_ids)
if not self.provider.is_readonly():
diff --git a/python/ray/autoscaler/_private/constants.py b/python/ray/autoscaler/_private/constants.py
index fdf38d8467e50..4ace80e91ec7a 100644
--- a/python/ray/autoscaler/_private/constants.py
+++ b/python/ray/autoscaler/_private/constants.py
@@ -20,6 +20,9 @@ def env_integer(key, default):
return default
+# Whether autoscaler cluster status logging is enabled. Set to 0 disable.
+AUTOSCALER_STATUS_LOG = env_integer("RAY_ENABLE_CLUSTER_STATUS_LOG", 1)
+
# The name of the environment variable for plugging in a utilization scorer.
AUTOSCALER_UTILIZATION_SCORER_KEY = "RAY_AUTOSCALER_UTILIZATION_SCORER"
diff --git a/python/ray/autoscaler/_private/kuberay/run_autoscaler.py b/python/ray/autoscaler/_private/kuberay/run_autoscaler.py
index 1d610a655e97f..657f5b2e32427 100644
--- a/python/ray/autoscaler/_private/kuberay/run_autoscaler.py
+++ b/python/ray/autoscaler/_private/kuberay/run_autoscaler.py
@@ -17,21 +17,7 @@
def run_kuberay_autoscaler(cluster_name: str, cluster_namespace: str):
- """Wait until the Ray head container is ready. Then start the autoscaler.
-
- For kuberay's autoscaler integration, the autoscaler runs in a sidecar container
- in the same pod as the main Ray container, which runs the rest of the Ray
- processes.
-
- The logging configuration here is for the sidecar container, but we need the
- logs to go to the same place as the head node logs because the autoscaler is
- allowed to send scaling events to Ray drivers' stdout. The implementation of
- this feature involves the autoscaler communicating to another Ray process
- (the log monitor) via logs in that directory.
-
- However, the Ray head container sets up the log directory. Thus, we set up
- logging only after the Ray head is ready.
- """
+ """Wait until the Ray head container is ready. Then start the autoscaler."""
head_ip = get_node_ip_address()
ray_address = f"{head_ip}:6379"
while True:
@@ -55,6 +41,8 @@ def run_kuberay_autoscaler(cluster_name: str, cluster_namespace: str):
print(f"Will check again in {BACKOFF_S} seconds.")
time.sleep(BACKOFF_S)
+ # The Ray head container sets up the log directory. Thus, we set up logging
+ # only after the Ray head is ready.
_setup_logging()
# autoscaling_config_producer reads the RayCluster CR from K8s and uses the CR
diff --git a/python/ray/autoscaler/_private/local/node_provider.py b/python/ray/autoscaler/_private/local/node_provider.py
index e9bb540b88fc6..cd3b7b64166fa 100644
--- a/python/ray/autoscaler/_private/local/node_provider.py
+++ b/python/ray/autoscaler/_private/local/node_provider.py
@@ -25,7 +25,8 @@
logger = logging.getLogger(__name__)
-logging.getLogger("filelock").setLevel(logging.WARNING)
+filelock_logger = logging.getLogger("filelock")
+filelock_logger.setLevel(logging.WARNING)
class ClusterState:
diff --git a/python/ray/data/_internal/execution/operators/actor_pool_submitter.py b/python/ray/data/_internal/execution/operators/actor_pool_submitter.py
index a98a2cf078d64..23727e907031a 100644
--- a/python/ray/data/_internal/execution/operators/actor_pool_submitter.py
+++ b/python/ray/data/_internal/execution/operators/actor_pool_submitter.py
@@ -28,8 +28,6 @@ def __init__(
ray_remote_args: Remote arguments for the Ray actors to be created.
pool_size: The size of the actor pool.
"""
- if "num_cpus" not in ray_remote_args:
- raise ValueError("Remote args should have explicit CPU spec.")
self._transform_fn_ref = transform_fn_ref
self._ray_remote_args = ray_remote_args
self._pool_size = pool_size
diff --git a/python/ray/data/_internal/execution/operators/all_to_all_operator.py b/python/ray/data/_internal/execution/operators/all_to_all_operator.py
index 166b73319c9b2..9a10adfee11d3 100644
--- a/python/ray/data/_internal/execution/operators/all_to_all_operator.py
+++ b/python/ray/data/_internal/execution/operators/all_to_all_operator.py
@@ -38,7 +38,11 @@ def __init__(
super().__init__(name, [input_op])
def num_outputs_total(self) -> Optional[int]:
- return self._num_outputs
+ return (
+ self._num_outputs
+ if self._num_outputs
+ else self.input_dependencies[0].num_outputs_total()
+ )
def add_input(self, refs: RefBundle, input_index: int) -> None:
assert not self.completed()
diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py
index 7791e9cf56fa6..a0062bfa25c24 100644
--- a/python/ray/data/_internal/execution/operators/map_operator.py
+++ b/python/ray/data/_internal/execution/operators/map_operator.py
@@ -130,12 +130,12 @@ def _canonicalize_ray_remote_args(ray_remote_args: Dict[str, Any]) -> Dict[str,
ray_remote_args = ray_remote_args.copy()
if "num_cpus" not in ray_remote_args and "num_gpus" not in ray_remote_args:
ray_remote_args["num_cpus"] = 1
- if "num_gpus" in ray_remote_args:
+ if ray_remote_args.get("num_gpus", 0) > 0:
if ray_remote_args.get("num_cpus", 0) != 0:
raise ValueError(
"It is not allowed to specify both num_cpus and num_gpus for map tasks."
)
- elif "num_cpus" in ray_remote_args:
+ elif ray_remote_args.get("num_cpus", 0) > 0:
if ray_remote_args.get("num_gpus", 0) != 0:
raise ValueError(
"It is not allowed to specify both num_cpus and num_gpus for map tasks."
diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py
index fe3dc92f12a48..1d3543100b0cb 100644
--- a/python/ray/data/_internal/execution/streaming_executor.py
+++ b/python/ray/data/_internal/execution/streaming_executor.py
@@ -1,9 +1,12 @@
import logging
+import os
from typing import Iterator, Optional
+import ray
from ray.data._internal.execution.interfaces import (
Executor,
ExecutionOptions,
+ ExecutionResources,
RefBundle,
PhysicalOperator,
)
@@ -15,10 +18,14 @@
process_completed_tasks,
select_operator_to_run,
)
+from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.stats import DatasetStats
logger = logging.getLogger(__name__)
+# Set this environment variable for detailed scheduler debugging logs.
+DEBUG_TRACE_SCHEDULING = "RAY_DATASET_TRACE_SCHEDULING" in os.environ
+
class StreamingExecutor(Executor):
"""A streaming Dataset executor.
@@ -33,6 +40,9 @@ def __init__(self, options: ExecutionOptions):
# object as data is streamed through (similar to how iterating over the output
# data updates the stats object in legacy code).
self._stats: Optional[DatasetStats] = None
+ self._global_info: Optional[ProgressBar] = None
+ if options.locality_with_output:
+ raise NotImplementedError("locality with output")
super().__init__(options)
def execute(
@@ -45,19 +55,29 @@ def execute(
"""
if not isinstance(dag, InputDataBuffer):
logger.info("Executing DAG %s", dag)
+ self._global_info = ProgressBar("Resource usage vs limits", 1, 0)
# Setup the streaming DAG topology.
topology, self._stats = build_streaming_topology(dag, self._options)
output_node: OpState = topology[dag]
- # Run scheduling loop until complete.
- while self._scheduling_loop_step(topology):
+ try:
+ _validate_topology(topology, self._get_or_refresh_resource_limits())
+ output_node: OpState = topology[dag]
+
+ # Run scheduling loop until complete.
+ while self._scheduling_loop_step(topology):
+ while output_node.outqueue:
+ yield output_node.outqueue.pop(0)
+
+ # Handle any leftover outputs.
while output_node.outqueue:
yield output_node.outqueue.pop(0)
-
- # Handle any leftover outputs.
- while output_node.outqueue:
- yield output_node.outqueue.pop(0)
+ finally:
+ for op in topology:
+ op.shutdown()
+ if self._global_info:
+ self._global_info.close()
def get_stats(self):
"""Return the stats object for the streaming execution.
@@ -79,16 +99,110 @@ def _scheduling_loop_step(self, topology: Topology) -> bool:
True if we should continue running the scheduling loop.
"""
+ if DEBUG_TRACE_SCHEDULING:
+ logger.info("Scheduling loop step...")
+
# Note: calling process_completed_tasks() is expensive since it incurs
# ray.wait() overhead, so make sure to allow multiple dispatch per call for
# greater parallelism.
process_completed_tasks(topology)
+ for op_state in topology.values():
+ op_state.refresh_progress_bar()
# Dispatch as many operators as we can for completed tasks.
- op = select_operator_to_run(topology)
+ limits = self._get_or_refresh_resource_limits()
+ cur_usage = self._get_current_usage(topology)
+ self._report_current_usage(cur_usage, limits)
+ op = select_operator_to_run(topology, cur_usage, limits)
while op is not None:
+ if DEBUG_TRACE_SCHEDULING:
+ _debug_dump_topology(topology)
topology[op].dispatch_next_task()
- op = select_operator_to_run(topology)
+ cur_usage = self._get_current_usage(topology)
+ op = select_operator_to_run(topology, cur_usage, limits)
# Keep going until all operators run to completion.
return not all(op.completed() for op in topology)
+
+ def _get_or_refresh_resource_limits(self) -> ExecutionResources:
+ """Return concrete limits for use at the current time.
+
+ This method autodetects any unspecified execution resource limits based on the
+ current cluster size, refreshing these values periodically to support cluster
+ autoscaling.
+ """
+ base = self._options.resource_limits
+ cluster = ray.cluster_resources()
+ return ExecutionResources(
+ cpu=base.cpu if base.cpu is not None else cluster.get("CPU", 0.0),
+ gpu=base.gpu if base.gpu is not None else cluster.get("GPU", 0.0),
+ object_store_memory=base.object_store_memory
+ if base.object_store_memory is not None
+ else cluster.get("object_store_memory", 0.0) // 4,
+ )
+
+ def _get_current_usage(self, topology: Topology) -> ExecutionResources:
+ cur_usage = ExecutionResources()
+ for op, state in topology.items():
+ cur_usage = cur_usage.add(op.current_resource_usage())
+ if isinstance(op, InputDataBuffer):
+ continue # Don't count input refs towards dynamic memory usage.
+ for bundle in state.outqueue:
+ cur_usage.object_store_memory += bundle.size_bytes()
+ return cur_usage
+
+ def _report_current_usage(
+ self, cur_usage: ExecutionResources, limits: ExecutionResources
+ ) -> None:
+ if self._global_info:
+ self._global_info.set_description(
+ "Resource usage vs limits: "
+ f"{cur_usage.cpu}/{limits.cpu} CPU, "
+ f"{cur_usage.gpu}/{limits.gpu} GPU, "
+ f"{cur_usage.object_store_memory_str()}/"
+ f"{limits.object_store_memory_str()} object_store_memory"
+ )
+
+
+def _validate_topology(topology: Topology, limits: ExecutionResources) -> None:
+ """Raises an exception on invalid topologies.
+
+ It checks if the the sum of min actor pool sizes are larger than the resource
+ limit, as well as other unsupported resource configurations.
+
+ Args:
+ topology: The topology to validate.
+ limits: The limits to validate against.
+ """
+
+ base_usage = ExecutionResources(cpu=1)
+ for op in topology:
+ base_usage = base_usage.add(op.base_resource_usage())
+ inc_usage = op.incremental_resource_usage()
+ if inc_usage.cpu and inc_usage.gpu:
+ raise NotImplementedError(
+ "Operator incremental resource usage cannot specify both CPU "
+ "and GPU at the same time, since it may cause deadlock."
+ )
+ elif inc_usage.object_store_memory:
+ raise NotImplementedError(
+ "Operator incremental resource usage must not include memory."
+ )
+
+ if not base_usage.satisfies_limit(limits):
+ raise ValueError(
+ f"The base resource usage of this topology {base_usage} "
+ f"exceeds the execution limits {limits}!"
+ )
+
+
+def _debug_dump_topology(topology: Topology) -> None:
+ """Print out current execution state for the topology for debugging.
+
+ Args:
+ topology: The topology to debug.
+ """
+ logger.info("vvv scheduling trace vvv")
+ for i, (op, state) in enumerate(topology.items()):
+ logger.info(f"{i}: {state.summary_str()}")
+ logger.info("^^^ scheduling trace ^^^")
diff --git a/python/ray/data/_internal/execution/streaming_executor_state.py b/python/ray/data/_internal/execution/streaming_executor_state.py
index 119b9a4f80434..29d416dca5934 100644
--- a/python/ray/data/_internal/execution/streaming_executor_state.py
+++ b/python/ray/data/_internal/execution/streaming_executor_state.py
@@ -3,10 +3,12 @@
This is split out from streaming_executor.py to facilitate better unit testing.
"""
+import math
from typing import Dict, List, Optional
import ray
from ray.data._internal.execution.interfaces import (
+ ExecutionResources,
RefBundle,
PhysicalOperator,
ExecutionOptions,
@@ -64,10 +66,15 @@ def add_output(self, ref: RefBundle) -> None:
def refresh_progress_bar(self) -> None:
"""Update the console with the latest operator progress."""
if self.progress_bar:
- queued = self.num_queued()
- self.progress_bar.set_description(
- f"{self.op.name}: {self.num_active_tasks()} active, {queued} queued"
- )
+ self.progress_bar.set_description(self.summary_str())
+
+ def summary_str(self) -> str:
+ queued = self.num_queued()
+ desc = f"{self.op.name}: {self.num_active_tasks()} active, {queued} queued"
+ suffix = self.op.progress_str()
+ if suffix:
+ desc += f", {suffix}"
+ return desc
def dispatch_next_task(self) -> None:
"""Move a bundle from the operator inqueue to the operator itself."""
@@ -117,8 +124,9 @@ def setup_state(op: PhysicalOperator) -> OpState:
setup_state(dag)
# Create the progress bars starting from the first operator to run.
- # Note that the topology dict is in topological sort order.
- i = 0
+ # Note that the topology dict is in topological sort order. Index zero is reserved
+ # for global progress information.
+ i = 1
for op_state in list(topology.values()):
if not isinstance(op_state.op, InputDataBuffer):
op_state.initialize_progress_bar(i)
@@ -130,8 +138,6 @@ def setup_state(op: PhysicalOperator) -> OpState:
def process_completed_tasks(topology: Topology) -> None:
"""Process any newly completed tasks and update operator state."""
- for op_state in topology.values():
- op_state.refresh_progress_bar()
# Update active tasks.
active_tasks: Dict[ray.ObjectRef, PhysicalOperator] = {}
@@ -170,7 +176,9 @@ def process_completed_tasks(topology: Topology) -> None:
op_state.inputs_done_called = True
-def select_operator_to_run(topology: Topology) -> Optional[PhysicalOperator]:
+def select_operator_to_run(
+ topology: Topology, cur_usage: ExecutionResources, limits: ExecutionResources
+) -> Optional[PhysicalOperator]:
"""Select an operator to run, if possible.
The objective of this function is to maximize the throughput of the overall
@@ -181,17 +189,19 @@ def select_operator_to_run(topology: Topology) -> Optional[PhysicalOperator]:
operators with a large number of running tasks `num_active_tasks()`.
"""
- # TODO: set limits properly based on resources and execution options. This is just
- # a hard-coded development placeholder.
- PARALLELISM_LIMIT = 8
- num_active_tasks = sum(
- op_state.num_active_tasks() for op_state in topology.values()
- )
- if num_active_tasks >= PARALLELISM_LIMIT:
- return None
+ # Filter to ops that are eligible for execution.
+ ops = [
+ op
+ for op, state in topology.items()
+ if state.num_queued() > 0 and _execution_allowed(op, cur_usage, limits)
+ ]
+
+ # To ensure liveness, allow at least 1 op to run regardless of limits.
+ if not ops and all(op.num_active_work_refs() == 0 for op in topology):
+ # The topology is entirely idle, so choose from all ready ops ignoring limits.
+ ops = [op for op, state in topology.items() if state.num_queued() > 0]
- # Filter to ops that have queued inputs.
- ops = [op for op, state in topology.items() if state.num_queued() > 0]
+ # Nothing to run.
if not ops:
return None
@@ -199,3 +209,35 @@ def select_operator_to_run(topology: Topology) -> Optional[PhysicalOperator]:
return min(
ops, key=lambda op: len(topology[op].outqueue) + topology[op].num_active_tasks()
)
+
+
+def _execution_allowed(
+ op: PhysicalOperator,
+ global_usage: ExecutionResources,
+ global_limits: ExecutionResources,
+) -> bool:
+ """Return whether an operator is allowed to execute given resource usage.
+
+ Args:
+ op: The operator to check.
+ global_usage: Resource usage across the entire topology.
+ global_limits: Execution resource limits.
+
+ Returns:
+ Whether the op is allowed to run.
+ """
+ # To avoid starvation problems when dealing with fractional resource types,
+ # convert all quantities to integer (0 or 1) for deciding admissibility. This
+ # allows operators with non-integral requests to slightly overshoot the limit.
+ global_floored = ExecutionResources(
+ cpu=math.floor(global_usage.cpu or 0),
+ gpu=math.floor(global_usage.gpu or 0),
+ object_store_memory=global_usage.object_store_memory,
+ )
+ inc = op.incremental_resource_usage()
+ inc_indicator = ExecutionResources(
+ cpu=1 if inc.cpu else 0,
+ gpu=1 if inc.gpu else 0,
+ object_store_memory=1 if inc.object_store_memory else 0,
+ )
+ return global_floored.add(inc_indicator).satisfies_limit(global_limits)
diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py
index 4157bcfc6780f..1741d61d9da1d 100644
--- a/python/ray/data/_internal/plan.py
+++ b/python/ray/data/_internal/plan.py
@@ -454,7 +454,7 @@ def execute_to_iterator(
execute_to_legacy_block_iterator,
)
- executor = StreamingExecutor(ExecutionOptions())
+ executor = StreamingExecutor(ExecutionOptions(preserve_order=False))
block_iter = execute_to_legacy_block_iterator(
executor,
self,
diff --git a/python/ray/data/context.py b/python/ray/data/context.py
index 3a376f822819e..1b735bdaf4066 100644
--- a/python/ray/data/context.py
+++ b/python/ray/data/context.py
@@ -67,7 +67,7 @@
# Whether to use the new executor backend.
DEFAULT_NEW_EXECUTION_BACKEND = bool(
- int(os.environ.get("RAY_DATASET_NEW_EXECUTION_BACKEND", "0"))
+ int(os.environ.get("RAY_DATASET_NEW_EXECUTION_BACKEND", "1"))
)
# Whether to use the streaming executor. This only has an effect if the new execution
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index 132f4aeb957e8..e593c5ca60ce2 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -5455,7 +5455,10 @@ def f(x):
.fully_executed()
)
- assert f"{max_size}/{max_size} blocks" in ds.stats()
+ # TODO(https://github.com/ray-project/ray/issues/31723): implement the feature
+ # of capping bundle size by actor pool size, and then re-enable this test.
+ if not DatasetContext.get_current().new_execution_backend:
+ assert f"{max_size}/{max_size} blocks" in ds.stats()
# Check batch size is still respected.
ds = (
diff --git a/python/ray/data/tests/test_dataset_pipeline.py b/python/ray/data/tests/test_dataset_pipeline.py
index b6c3fa0dfb7ba..c3e8a398e4e06 100644
--- a/python/ray/data/tests/test_dataset_pipeline.py
+++ b/python/ray/data/tests/test_dataset_pipeline.py
@@ -814,22 +814,6 @@ def consume(pipe, owned_by_consumer):
ray.get([consume.remote(splits[0], True), consume.remote(splits[1], True)])
-# Run at end of file to avoid segfault https://github.com/ray-project/ray/issues/31145
-def test_incremental_take(shutdown_only):
- ray.shutdown()
- ray.init(num_cpus=2)
-
- # Can read incrementally even if future results are delayed.
- def block_on_ones(x: int) -> int:
- if x == 1:
- time.sleep(999999)
- return x
-
- pipe = ray.data.range(2).window(blocks_per_window=1)
- pipe = pipe.map(block_on_ones)
- assert pipe.take(1) == [0]
-
-
if __name__ == "__main__":
import sys
diff --git a/python/ray/data/tests/test_object_gc.py b/python/ray/data/tests/test_object_gc.py
index 951c97e1cd23d..67c627ce3afc4 100644
--- a/python/ray/data/tests/test_object_gc.py
+++ b/python/ray/data/tests/test_object_gc.py
@@ -126,26 +126,22 @@ def test_iter_batches_no_spilling_upon_shuffle(shutdown_only):
def test_pipeline_splitting_has_no_spilling(shutdown_only):
# The object store is about 800MiB.
- ctx = ray.init(num_cpus=1, object_store_memory=800e6)
+ ctx = ray.init(num_cpus=1, object_store_memory=1200e6)
# The size of dataset is 50000*(80*80*4)*8B, about 10GiB, 50MiB/block.
- ds = ray.data.range_tensor(50000, shape=(80, 80, 4), parallelism=200)
+ ds = ray.data.range_tensor(5000, shape=(80, 80, 4), parallelism=20)
# 2 blocks/window.
- p = ds.window(bytes_per_window=100 * 1024 * 1024).repeat()
+ p = ds.window(bytes_per_window=100 * 1024 * 1024).repeat(2)
p1, p2 = p.split(2)
@ray.remote
def consume(p):
for batch in p.iter_batches(batch_size=None):
pass
+ print(p.stats())
tasks = [consume.remote(p1), consume.remote(p2)]
- try:
- # Run it for 20 seconds.
- ray.get(tasks, timeout=20)
- except Exception:
- for t in tasks:
- ray.cancel(t, force=True)
+ ray.get(tasks)
meminfo = memory_summary(ctx.address_info["address"], stats_only=True)
assert "Spilled" not in meminfo, meminfo
diff --git a/python/ray/data/tests/test_operators.py b/python/ray/data/tests/test_operators.py
index e2c025c8f4c7a..1a9d946cdaa61 100644
--- a/python/ray/data/tests/test_operators.py
+++ b/python/ray/data/tests/test_operators.py
@@ -71,6 +71,22 @@ def dummy_all_transform(bundles: List[RefBundle]):
assert op.completed()
+def test_num_outputs_total():
+ input_op = InputDataBuffer(make_ref_bundles([[i] for i in range(100)]))
+ op1 = MapOperator(
+ _mul2_transform,
+ input_op=input_op,
+ name="TestMapper",
+ )
+ assert op1.num_outputs_total() == 100
+
+ def dummy_all_transform(bundles: List[RefBundle]):
+ return make_ref_bundles([[1, 2], [3, 4]]), {"FooStats": []}
+
+ op2 = AllToAllOperator(dummy_all_transform, input_op=op1, name="TestAll")
+ assert op2.num_outputs_total() == 100
+
+
@pytest.mark.parametrize("use_actors", [False, True])
def test_map_operator_bulk(ray_start_regular_shared, use_actors):
# Create with inputs.
diff --git a/python/ray/data/tests/test_pipeline_incremental_take.py b/python/ray/data/tests/test_pipeline_incremental_take.py
new file mode 100644
index 0000000000000..be8357a78ae60
--- /dev/null
+++ b/python/ray/data/tests/test_pipeline_incremental_take.py
@@ -0,0 +1,31 @@
+import time
+import pytest
+import ray
+from ray.data.context import DatasetContext
+
+from ray.tests.conftest import * # noqa
+
+
+def test_incremental_take(shutdown_only):
+ # TODO(https://github.com/ray-project/ray/issues/31145): re-enable
+ # after the segfault bug is fixed.
+ if DatasetContext.get_current().new_execution_backend:
+ return
+
+ ray.init(num_cpus=2)
+
+ # Can read incrementally even if future results are delayed.
+ def block_on_ones(x: int) -> int:
+ if x == 1:
+ time.sleep(999999)
+ return x
+
+ pipe = ray.data.range(2).window(blocks_per_window=1)
+ pipe = pipe.map(block_on_ones)
+ assert pipe.take(1) == [0]
+
+
+if __name__ == "__main__":
+ import sys
+
+ sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py
index d33b624259046..c5c85c27ddda7 100644
--- a/python/ray/data/tests/test_stats.py
+++ b/python/ray/data/tests/test_stats.py
@@ -54,7 +54,7 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
"""
)
@@ -85,7 +85,7 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
"""
)
@@ -115,7 +115,7 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
Stage N map: N/N blocks executed in T
@@ -133,7 +133,7 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
* In format_batch(): T
* In user code: T
* Total time: T
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
"""
)
@@ -266,7 +266,7 @@ def test_dataset_stats_read_parquet(ray_start_regular_shared, tmp_path):
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
"""
)
@@ -302,7 +302,7 @@ def test_dataset_split_stats(ray_start_regular_shared, tmp_path):
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
Stage N split: N/N blocks executed in T
@@ -320,7 +320,7 @@ def test_dataset_split_stats(ray_start_regular_shared, tmp_path):
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
"""
)
@@ -384,7 +384,7 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
"""
)
@@ -416,7 +416,7 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
"""
)
@@ -451,7 +451,7 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
"""
)
@@ -480,7 +480,7 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
Stage N map: N/N blocks executed in T
@@ -490,12 +490,12 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
== Pipeline Window N ==
Stage N read->map_batches: [execution cached]
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
Stage N map: N/N blocks executed in T
@@ -505,12 +505,12 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
== Pipeline Window N ==
Stage N read->map_batches: [execution cached]
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
Stage N map: N/N blocks executed in T
@@ -520,7 +520,7 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
##### Overall Pipeline Time Breakdown #####
@@ -629,6 +629,7 @@ def consume(split):
s0, s1 = pipe.split(2)
stats = ray.get([consume.remote(s0), consume.remote(s1)])
if context.new_execution_backend:
+ print("XXX stats:", canonicalize(stats[0]))
assert (
canonicalize(stats[0])
== """== Pipeline Window Z ==
@@ -639,7 +640,7 @@ def consume(split):
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
== Pipeline Window N ==
@@ -650,7 +651,7 @@ def consume(split):
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
-* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': Z, \
+* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
##### Overall Pipeline Time Breakdown #####
diff --git a/python/ray/data/tests/test_streaming_executor.py b/python/ray/data/tests/test_streaming_executor.py
index 6048916eca656..27095867617f7 100644
--- a/python/ray/data/tests/test_streaming_executor.py
+++ b/python/ray/data/tests/test_streaming_executor.py
@@ -1,23 +1,27 @@
import pytest
-import asyncio
import time
from unittest.mock import MagicMock
from typing import List, Any
import ray
-from ray.data.context import DatasetContext
from ray.data._internal.execution.interfaces import (
ExecutionOptions,
+ ExecutionResources,
RefBundle,
PhysicalOperator,
)
-from ray.data._internal.execution.streaming_executor import StreamingExecutor
+from ray.data._internal.execution.streaming_executor import (
+ StreamingExecutor,
+ _debug_dump_topology,
+ _validate_topology,
+)
from ray.data._internal.execution.streaming_executor_state import (
OpState,
build_streaming_topology,
process_completed_tasks,
select_operator_to_run,
+ _execution_allowed,
)
from ray.data._internal.execution.operators.all_to_all_operator import AllToAllOperator
from ray.data._internal.execution.operators.map_operator import MapOperator
@@ -104,28 +108,41 @@ def test_process_completed_tasks():
def test_select_operator_to_run():
+ opt = ExecutionOptions()
inputs = make_ref_bundles([[x] for x in range(20)])
o1 = InputDataBuffer(inputs)
o2 = MapOperator(make_transform(lambda block: [b * -1 for b in block]), o1)
o3 = MapOperator(make_transform(lambda block: [b * 2 for b in block]), o2)
- topo, _ = build_streaming_topology(o3, ExecutionOptions())
+ topo, _ = build_streaming_topology(o3, opt)
# Test empty.
- assert select_operator_to_run(topo) is None
+ assert (
+ select_operator_to_run(topo, ExecutionResources(), ExecutionResources()) is None
+ )
# Test backpressure based on queue length between operators.
topo[o1].outqueue.append("dummy1")
- assert select_operator_to_run(topo) == o2
+ assert (
+ select_operator_to_run(topo, ExecutionResources(), ExecutionResources()) == o2
+ )
topo[o1].outqueue.append("dummy2")
- assert select_operator_to_run(topo) == o2
+ assert (
+ select_operator_to_run(topo, ExecutionResources(), ExecutionResources()) == o2
+ )
topo[o2].outqueue.append("dummy3")
- assert select_operator_to_run(topo) == o3
+ assert (
+ select_operator_to_run(topo, ExecutionResources(), ExecutionResources()) == o3
+ )
# Test backpressure includes num active tasks as well.
topo[o3].num_active_tasks = MagicMock(return_value=2)
- assert select_operator_to_run(topo) == o2
+ assert (
+ select_operator_to_run(topo, ExecutionResources(), ExecutionResources()) == o2
+ )
topo[o2].num_active_tasks = MagicMock(return_value=2)
- assert select_operator_to_run(topo) == o3
+ assert (
+ select_operator_to_run(topo, ExecutionResources(), ExecutionResources()) == o3
+ )
def test_dispatch_next_task():
@@ -148,6 +165,114 @@ def test_dispatch_next_task():
assert o2.add_input.called_once_with("dummy2")
+def test_debug_dump_topology():
+ opt = ExecutionOptions()
+ inputs = make_ref_bundles([[x] for x in range(20)])
+ o1 = InputDataBuffer(inputs)
+ o2 = MapOperator(make_transform(lambda block: [b * -1 for b in block]), o1)
+ o3 = MapOperator(make_transform(lambda block: [b * 2 for b in block]), o2)
+ topo, _ = build_streaming_topology(o3, opt)
+ # Just a sanity check to ensure it doesn't crash.
+ _debug_dump_topology(topo)
+
+
+def test_validate_topology():
+ opt = ExecutionOptions()
+ inputs = make_ref_bundles([[x] for x in range(20)])
+ o1 = InputDataBuffer(inputs)
+ o2 = MapOperator(
+ make_transform(lambda block: [b * -1 for b in block]),
+ o1,
+ compute_strategy=ray.data.ActorPoolStrategy(8, 8),
+ )
+ o3 = MapOperator(
+ make_transform(lambda block: [b * 2 for b in block]),
+ o2,
+ compute_strategy=ray.data.ActorPoolStrategy(4, 4),
+ )
+ topo, _ = build_streaming_topology(o3, opt)
+ _validate_topology(topo, ExecutionResources())
+ _validate_topology(topo, ExecutionResources(cpu=20))
+ _validate_topology(topo, ExecutionResources(gpu=0))
+ with pytest.raises(ValueError):
+ _validate_topology(topo, ExecutionResources(cpu=10))
+
+
+def test_execution_allowed():
+ op = InputDataBuffer([])
+
+ # CPU.
+ op.incremental_resource_usage = MagicMock(return_value=ExecutionResources(cpu=1))
+ assert _execution_allowed(op, ExecutionResources(cpu=1), ExecutionResources(cpu=2))
+ assert not _execution_allowed(
+ op, ExecutionResources(cpu=2), ExecutionResources(cpu=2)
+ )
+ assert _execution_allowed(op, ExecutionResources(cpu=2), ExecutionResources(gpu=2))
+
+ # GPU.
+ op.incremental_resource_usage = MagicMock(
+ return_value=ExecutionResources(cpu=1, gpu=1)
+ )
+ assert _execution_allowed(op, ExecutionResources(gpu=1), ExecutionResources(gpu=2))
+ assert not _execution_allowed(
+ op, ExecutionResources(gpu=2), ExecutionResources(gpu=2)
+ )
+
+ # Test conversion to indicator (0/1).
+ op.incremental_resource_usage = MagicMock(
+ return_value=ExecutionResources(cpu=100, gpu=100)
+ )
+ assert _execution_allowed(op, ExecutionResources(gpu=1), ExecutionResources(gpu=2))
+ assert _execution_allowed(
+ op, ExecutionResources(gpu=1.5), ExecutionResources(gpu=2)
+ )
+ assert not _execution_allowed(
+ op, ExecutionResources(gpu=2), ExecutionResources(gpu=2)
+ )
+
+ # Test conversion to indicator (0/1).
+ op.incremental_resource_usage = MagicMock(
+ return_value=ExecutionResources(cpu=0.1, gpu=0.1)
+ )
+ assert _execution_allowed(op, ExecutionResources(gpu=1), ExecutionResources(gpu=2))
+ assert _execution_allowed(
+ op, ExecutionResources(gpu=1.5), ExecutionResources(gpu=2)
+ )
+ assert not _execution_allowed(
+ op, ExecutionResources(gpu=2), ExecutionResources(gpu=2)
+ )
+
+
+def test_select_ops_ensure_at_least_one_live_operator():
+ opt = ExecutionOptions()
+ inputs = make_ref_bundles([[x] for x in range(20)])
+ o1 = InputDataBuffer(inputs)
+ o2 = MapOperator(
+ make_transform(lambda block: [b * -1 for b in block]),
+ o1,
+ )
+ o3 = MapOperator(
+ make_transform(lambda block: [b * 2 for b in block]),
+ o2,
+ )
+ topo, _ = build_streaming_topology(o3, opt)
+ topo[o2].outqueue.append("dummy1")
+ o1.num_active_work_refs = MagicMock(return_value=2)
+ assert (
+ select_operator_to_run(
+ topo, ExecutionResources(cpu=1), ExecutionResources(cpu=1)
+ )
+ is None
+ )
+ o1.num_active_work_refs = MagicMock(return_value=0)
+ assert (
+ select_operator_to_run(
+ topo, ExecutionResources(cpu=1), ExecutionResources(cpu=1)
+ )
+ is o3
+ )
+
+
def test_pipelined_execution():
executor = StreamingExecutor(ExecutionOptions())
inputs = make_ref_bundles([[x] for x in range(20)])
@@ -166,31 +291,6 @@ def reverse_sort(inputs: List[RefBundle]):
assert output == expected, (output, expected)
-# TODO(ekl) remove this test once we have the new backend on by default.
-def test_e2e_streaming_sanity():
- DatasetContext.get_current().new_execution_backend = True
- DatasetContext.get_current().use_streaming_executor = True
-
- @ray.remote
- class Barrier:
- async def admit(self, x):
- if x == 4:
- print("Not allowing 4 to pass")
- await asyncio.sleep(999)
- else:
- print(f"Allowing {x} to pass")
-
- barrier = Barrier.remote()
-
- def f(x):
- ray.get(barrier.admit.remote(x))
- return x + 1
-
- # Check we can take the first items even if the last one gets stuck.
- result = ray.data.range(5, parallelism=5).map(f)
- assert result.take(4) == [1, 2, 3, 4]
-
-
if __name__ == "__main__":
import sys
diff --git a/python/ray/experimental/raysort/main.py b/python/ray/experimental/raysort/main.py
index 44270da318033..608b7b304f4c0 100644
--- a/python/ray/experimental/raysort/main.py
+++ b/python/ray/experimental/raysort/main.py
@@ -447,6 +447,7 @@ def init(args: Args):
ray.init(resources={"worker": os.cpu_count()})
else:
ray.init(address=args.ray_address)
+ logging_utils.init()
logging.info(args)
os.makedirs(constants.WORK_DIR, exist_ok=True)
resources = ray.cluster_resources()
diff --git a/python/ray/experimental/state/common.py b/python/ray/experimental/state/common.py
index 6d2545a0b6773..664041240aca0 100644
--- a/python/ray/experimental/state/common.py
+++ b/python/ray/experimental/state/common.py
@@ -368,6 +368,8 @@ class ActorState(StateSchema):
death_cause: Optional[dict] = state_column(filterable=False, detail=True)
#: True if the actor is detached. False otherwise.
is_detached: bool = state_column(filterable=False, detail=True)
+ #: The placement group id that's associated with this actor.
+ placement_group_id: str = state_column(detail=True, filterable=True)
@dataclass(init=True)
@@ -416,6 +418,12 @@ class NodeState(StateSchema):
node_name: str = state_column(filterable=True)
#: The total resources of the node.
resources_total: dict = state_column(filterable=False)
+ #: The time when the node (raylet) starts.
+ start_time_ms: int = state_column(filterable=False, detail=True)
+ #: The time when the node exits. The timestamp could be delayed
+ #: if the node is dead unexpectedly (could be delayed
+ # up to 30 seconds).
+ end_time_ms: int = state_column(filterable=False, detail=True)
class JobState(JobInfo, StateSchema):
@@ -464,9 +472,14 @@ class WorkerState(StateSchema):
#: The ip address of the worker.
ip: str = state_column(filterable=True)
#: The pid of the worker.
- pid: str = state_column(filterable=True)
+ pid: int = state_column(filterable=True)
#: The exit detail of the worker if the worker is dead.
exit_detail: Optional[str] = state_column(detail=True, filterable=False)
+ #: The time when the worker is started and initialized.
+ start_time_ms: int = state_column(filterable=False, detail=True)
+ #: The time when the worker exits. The timestamp could be delayed
+ #: if the worker is dead unexpectedly.
+ end_time_ms: int = state_column(filterable=False, detail=True)
@dataclass(init=True)
@@ -494,7 +507,7 @@ class TaskState(StateSchema):
#: Refer to src/ray/protobuf/common.proto for a detailed explanation of the state
#: breakdowns and typical state transition flow.
#:
- scheduling_state: TypeTaskStatus = state_column(filterable=True)
+ state: TypeTaskStatus = state_column(filterable=True)
#: The job id of this task.
job_id: str = state_column(filterable=True)
#: Id of the node that runs the task. If the task is retried, it could
@@ -523,6 +536,20 @@ class TaskState(StateSchema):
runtime_env_info: str = state_column(detail=True, filterable=False)
#: The parent task id.
parent_task_id: str = state_column(filterable=True)
+ #: The placement group id that's associated with this task.
+ placement_group_id: str = state_column(detail=True, filterable=True)
+ #: The worker id that's associated with this task.
+ worker_id: str = state_column(detail=True, filterable=True)
+ #: The list of events of the given task.
+ #: Refer to src/ray/protobuf/common.proto for a detailed explanation of the state
+ #: breakdowns and typical state transition flow.
+ events: List[dict] = state_column(detail=True, filterable=False)
+ #: The list of profile events of the given task.
+ profiling_data: List[dict] = state_column(detail=True, filterable=False)
+ #: The time when the task starts to run. A Unix timestamp in ms.
+ start_time_ms: Optional[int] = state_column(detail=True, filterable=False)
+ #: The time when the task finishes or failed. A Unix timestamp in ms.
+ end_time_ms: Optional[int] = state_column(detail=True, filterable=False)
@dataclass(init=True)
@@ -740,7 +767,7 @@ def to_summary(cls, *, tasks: List[Dict]):
)
task_summary = summary[key]
- state = task["scheduling_state"]
+ state = task["state"]
if state not in task_summary.state_counts:
task_summary.state_counts[state] = 0
task_summary.state_counts[state] += 1
diff --git a/python/ray/experimental/state/state_manager.py b/python/ray/experimental/state/state_manager.py
index e3b378fbd3e39..c82c0211f3112 100644
--- a/python/ray/experimental/state/state_manager.py
+++ b/python/ray/experimental/state/state_manager.py
@@ -11,6 +11,8 @@
import ray.dashboard.modules.log.log_consts as log_consts
from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
+from ray._private.utils import hex_to_binary
+from ray._raylet import JobID
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.gcs_service_pb2 import (
GetAllActorInfoReply,
@@ -232,11 +234,15 @@ async def get_all_actor_info(
@handle_grpc_network_errors
async def get_all_task_info(
- self, timeout: int = None, limit: int = None
+ self, timeout: int = None, limit: int = None, job_id: Optional[str] = None
) -> Optional[GetTaskEventsReply]:
if not limit:
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
- request = GetTaskEventsRequest(limit=limit, exclude_driver_task=True)
+ if job_id:
+ job_id = JobID(hex_to_binary(job_id)).binary()
+ request = GetTaskEventsRequest(
+ limit=limit, exclude_driver_task=True, job_id=job_id
+ )
reply = await self._gcs_task_info_stub.GetTaskEvents(request, timeout=timeout)
return reply
diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd
index 8e33645922539..f05d88516589b 100644
--- a/python/ray/includes/global_state_accessor.pxd
+++ b/python/ray/includes/global_state_accessor.pxd
@@ -2,6 +2,7 @@ from libcpp.string cimport string as c_string
from libcpp cimport bool as c_bool
from libcpp.vector cimport vector as c_vector
from libcpp.memory cimport unique_ptr
+from libc.stdint cimport int32_t as c_int32_t
from ray.includes.unique_ids cimport (
CActorID,
CJobID,
@@ -43,3 +44,65 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
CRayStatus GetNodeToConnectForDriver(
const c_string &node_ip_address,
c_string *node_to_connect)
+
+cdef extern from * namespace "ray::gcs" nogil:
+ """
+ #include
+ #include "ray/gcs/redis_client.h"
+ namespace ray {
+ namespace gcs {
+
+ class Cleanup {
+ public:
+ Cleanup(std::function f): f_(f) {}
+ ~Cleanup() { f_(); }
+ private:
+ std::function f_;
+ };
+
+ bool RedisDelKeySync(const std::string& host,
+ int32_t port,
+ const std::string& password,
+ bool use_ssl,
+ const std::string& key) {
+ RedisClientOptions options(host, port, password, false, use_ssl);
+ auto cli = std::make_unique(options);
+
+ instrumented_io_context io_service;
+
+ auto thread = std::make_unique([&]() {
+ boost::asio::io_service::work work(io_service);
+ io_service.run();
+ });
+
+ Cleanup _([&](){
+ io_service.stop();
+ thread->join();
+ });
+
+ auto status = cli->Connect(io_service);
+ if(!status.ok()) {
+ RAY_LOG(ERROR) << "Failed to connect to redis: " << status.ToString();
+ return false;
+ }
+
+ auto context = cli->GetShardContext(key);
+ auto cmd = std::vector{"DEL", key};
+ auto reply = context->RunArgvSync(cmd);
+ if(reply->ReadAsInteger() == 1) {
+ RAY_LOG(INFO) << "Successfully deleted " << key;
+ return true;
+ } else {
+ RAY_LOG(ERROR) << "Failed to delete " << key;
+ return false;
+ }
+ }
+
+ }
+ }
+ """
+ c_bool RedisDelKeySync(const c_string& host,
+ c_int32_t port,
+ const c_string& password,
+ c_bool use_ssl,
+ const c_string& key)
diff --git a/python/ray/includes/global_state_accessor.pxi b/python/ray/includes/global_state_accessor.pxi
index f98a5181d60ac..8492ee56a89bb 100644
--- a/python/ray/includes/global_state_accessor.pxi
+++ b/python/ray/includes/global_state_accessor.pxi
@@ -12,6 +12,7 @@ from ray.includes.unique_ids cimport (
from ray.includes.global_state_accessor cimport (
CGlobalStateAccessor,
+ RedisDelKeySync,
)
from libcpp.string cimport string as c_string
diff --git a/python/ray/serve/_private/client.py b/python/ray/serve/_private/client.py
index 9e561008c13ae..36cfad1e45a51 100644
--- a/python/ray/serve/_private/client.py
+++ b/python/ray/serve/_private/client.py
@@ -484,7 +484,7 @@ def get_deploy_args(
"deployment_config_proto_bytes": deployment_config.to_proto_bytes(),
"replica_config_proto_bytes": replica_config.to_proto_bytes(),
"route_prefix": route_prefix,
- "deployer_job_id": ray.get_runtime_context().job_id,
+ "deployer_job_id": ray.get_runtime_context().get_job_id(),
"is_driver_deployment": is_driver_deployment,
}
diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py
index a04175881c034..439bec125495a 100644
--- a/python/ray/serve/_private/common.py
+++ b/python/ray/serve/_private/common.py
@@ -162,7 +162,7 @@ def __init__(
deployment_config: DeploymentConfig,
replica_config: ReplicaConfig,
start_time_ms: int,
- deployer_job_id: "ray._raylet.JobID",
+ deployer_job_id: str,
actor_name: Optional[str] = None,
version: Optional[str] = None,
end_time_ms: Optional[int] = None,
@@ -225,7 +225,7 @@ def from_proto(cls, proto: DeploymentInfoProto):
"actor_name": proto.actor_name if proto.actor_name != "" else None,
"version": proto.version if proto.version != "" else None,
"end_time_ms": proto.end_time_ms if proto.end_time_ms != 0 else None,
- "deployer_job_id": ray.get_runtime_context().job_id,
+ "deployer_job_id": ray.get_runtime_context().get_job_id(),
}
return cls(**data)
diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py
index 7b12e4db82d6f..45b8652413957 100644
--- a/python/ray/serve/_private/deployment_state.py
+++ b/python/ray/serve/_private/deployment_state.py
@@ -473,7 +473,7 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[DeploymentVersion]
)
self._health_check_period_s = deployment_config.health_check_period_s
self._health_check_timeout_s = deployment_config.health_check_timeout_s
- self._node_id = ray.get(self._allocated_obj_ref).hex()
+ self._node_id = ray.get(self._allocated_obj_ref)
except Exception:
logger.exception(f"Exception in deployment '{self._deployment_name}'")
return ReplicaStartupStatus.FAILED, None
diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py
index ab071332409b5..b00139d7be287 100644
--- a/python/ray/serve/_private/http_proxy.py
+++ b/python/ray/serve/_private/http_proxy.py
@@ -37,10 +37,22 @@
SOCKET_REUSE_PORT_ENABLED = (
os.environ.get("SERVE_SOCKET_REUSE_PORT_ENABLED", "1") == "1"
)
-SERVE_REQUEST_PROCESSING_TIMEOUT_S = (
- float(os.environ.get("SERVE_REQUEST_PROCESSING_TIMEOUT_S", 0)) or None
+
+# TODO (shrekris-anyscale): Deprecate SERVE_REQUEST_PROCESSING_TIMEOUT_S env var
+RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S = (
+ float(os.environ.get("RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S", 0))
+ or float(os.environ.get("SERVE_REQUEST_PROCESSING_TIMEOUT_S", 0))
+ or None
)
+if os.environ.get("SERVE_REQUEST_PROCESSING_TIMEOUT_S") is not None:
+ logger.warning(
+ "The `SERVE_REQUEST_PROCESSING_TIMEOUT_S` environment variable has "
+ "been deprecated. Please use `RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S` "
+ "instead. `SERVE_REQUEST_PROCESSING_TIMEOUT_S` will be ignored in "
+ "future versions."
+ )
+
async def _send_request_to_handle(handle, scope, receive, send) -> str:
http_body_bytes = await receive_http_body(scope, receive, send)
@@ -90,14 +102,14 @@ async def _send_request_to_handle(handle, scope, receive, send) -> str:
# https://github.com/ray-project/ray/pull/29534 for more info.
_, request_timed_out = await asyncio.wait(
- [object_ref], timeout=SERVE_REQUEST_PROCESSING_TIMEOUT_S
+ [object_ref], timeout=RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S
)
if request_timed_out:
logger.info(
"Request didn't finish within "
- f"{SERVE_REQUEST_PROCESSING_TIMEOUT_S} seconds. Retrying "
+ f"{RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S} seconds. Retrying "
"with another replica. You can modify this timeout by "
- 'setting the "SERVE_REQUEST_PROCESSING_TIMEOUT_S" env var.'
+ 'setting the "RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S" env var.'
)
backoff = True
else:
@@ -421,11 +433,12 @@ async def ready(self):
"""Returns when HTTP proxy is ready to serve traffic.
Or throw exception when it is not able to serve traffic.
"""
+ setup_task = get_or_create_event_loop().create_task(self.setup_complete.wait())
done_set, _ = await asyncio.wait(
[
# Either the HTTP setup has completed.
# The event is set inside self.run.
- self.setup_complete.wait(),
+ setup_task,
# Or self.run errored.
self.running_task,
],
diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py
index c87b7ce9f7bb6..eb645330fdd5c 100644
--- a/python/ray/serve/_private/replica.py
+++ b/python/ray/serve/_private/replica.py
@@ -212,7 +212,7 @@ async def is_allocated(self) -> str:
Return the NodeID of this replica
"""
- return ray.get_runtime_context().node_id
+ return ray.get_runtime_context().get_node_id()
async def is_initialized(
self, user_config: Optional[Any] = None, _after: Optional[Any] = None
diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py
index dff31a2c8f587..9dd231c84166c 100644
--- a/python/ray/serve/controller.py
+++ b/python/ray/serve/controller.py
@@ -271,7 +271,7 @@ def _put_serve_snapshot(self) -> None:
entry = dict()
entry["name"] = deployment_name
entry["namespace"] = ray.get_runtime_context().namespace
- entry["ray_job_id"] = deployment_info.deployer_job_id.hex()
+ entry["ray_job_id"] = deployment_info.deployer_job_id
entry["class_name"] = deployment_info.replica_config.deployment_def_name
entry["version"] = deployment_info.version
entry["http_route"] = route_prefix
@@ -351,7 +351,7 @@ def deploy(
deployment_config_proto_bytes: bytes,
replica_config_proto_bytes: bytes,
route_prefix: Optional[str],
- deployer_job_id: Union["ray._raylet.JobID", bytes],
+ deployer_job_id: Union[str, bytes],
is_driver_deployment: Optional[bool] = False,
) -> bool:
if route_prefix is not None:
@@ -381,10 +381,13 @@ def deploy(
autoscaling_policy = BasicAutoscalingPolicy(autoscaling_config)
else:
autoscaling_policy = None
+
+ # Java API passes in JobID as bytes
if isinstance(deployer_job_id, bytes):
deployer_job_id = ray.JobID.from_int(
int.from_bytes(deployer_job_id, "little")
- )
+ ).hex()
+
deployment_info = DeploymentInfo(
actor_name=name,
version=version,
diff --git a/python/ray/serve/scripts.py b/python/ray/serve/scripts.py
index 7a54664e8626a..ec5cfee81a8d2 100644
--- a/python/ray/serve/scripts.py
+++ b/python/ray/serve/scripts.py
@@ -272,7 +272,11 @@ def deploy(config_file_name: str, address: str):
@click.option(
"--gradio",
is_flag=True,
- help=("Whether to enable gradio visualization of deployment graph."),
+ help=(
+ "Whether to enable gradio visualization of deployment graph. The "
+ "visualization can only be used with deployment graphs with DAGDriver "
+ "as the ingress deployment."
+ ),
)
def run(
config_or_import_path: str,
diff --git a/python/ray/serve/tests/test_cross_language.py b/python/ray/serve/tests/test_cross_language.py
index 72be1ffdce24c..916b4f31f06c7 100644
--- a/python/ray/serve/tests/test_cross_language.py
+++ b/python/ray/serve/tests/test_cross_language.py
@@ -37,7 +37,7 @@ def test_controller_starts_java_replica(shutdown_only): # noqa: F811
deployment_config_proto_bytes=config.to_proto_bytes(),
replica_config_proto_bytes=replica_config.to_proto_bytes(),
route_prefix=None,
- deployer_job_id=ray.get_runtime_context().job_id,
+ deployer_job_id=ray.get_runtime_context().get_job_id(),
)
)
assert updating
diff --git a/python/ray/serve/tests/test_deployment_state.py b/python/ray/serve/tests/test_deployment_state.py
index 99dadb1290030..1ce865997d96e 100644
--- a/python/ray/serve/tests/test_deployment_state.py
+++ b/python/ray/serve/tests/test_deployment_state.py
@@ -180,7 +180,7 @@ def deployment_info(
num_replicas=num_replicas, user_config=user_config, **config_opts
),
replica_config=ReplicaConfig.create(lambda x: x),
- deployer_job_id=ray.JobID.nil(),
+ deployer_job_id="",
is_driver_deployment=is_driver_deployment,
)
diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py
index b2f8b08194eb4..2a00ddb952648 100644
--- a/python/ray/serve/tests/test_standalone.py
+++ b/python/ray/serve/tests/test_standalone.py
@@ -143,7 +143,7 @@ def test_detached_deployment(ray_cluster):
# Create first job, check we can run a simple serve endpoint
ray.init(head_node.address, namespace=SERVE_NAMESPACE)
- first_job_id = ray.get_runtime_context().job_id
+ first_job_id = ray.get_runtime_context().get_job_id()
serve.start(detached=True)
@serve.deployment(route_prefix="/say_hi_f")
@@ -159,7 +159,7 @@ def f(*args):
# Create the second job, make sure we can still create new deployments.
ray.init(head_node.address, namespace="serve")
- assert ray.get_runtime_context().job_id != first_job_id
+ assert ray.get_runtime_context().get_job_id() != first_job_id
@serve.deployment(route_prefix="/say_hi_g")
def g(*args):
diff --git a/python/ray/serve/tests/test_standalone2.py b/python/ray/serve/tests/test_standalone2.py
index cf726eccd6204..35958d36670ac 100644
--- a/python/ray/serve/tests/test_standalone2.py
+++ b/python/ray/serve/tests/test_standalone2.py
@@ -748,7 +748,16 @@ def f():
class TestServeRequestProcessingTimeoutS:
@pytest.mark.parametrize(
- "ray_instance", [{"SERVE_REQUEST_PROCESSING_TIMEOUT_S": "5"}], indirect=True
+ "ray_instance",
+ [
+ {"RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "5"},
+ {"SERVE_REQUEST_PROCESSING_TIMEOUT_S": "5"},
+ {
+ "RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "5",
+ "SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0",
+ },
+ ],
+ indirect=True,
)
def test_normal_operation(self, ray_instance):
"""Checks that a moderate timeout doesn't affect normal operation."""
@@ -765,7 +774,16 @@ def f(*args):
serve.shutdown()
@pytest.mark.parametrize(
- "ray_instance", [{"SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0.1"}], indirect=True
+ "ray_instance",
+ [
+ {"RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0.1"},
+ {"SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0.1"},
+ {
+ "RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0.1",
+ "SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0",
+ },
+ ],
+ indirect=True,
)
def test_hanging_request(self, ray_instance):
"""Checks that the env var mitigates the hang."""
diff --git a/python/ray/setup-dev.py b/python/ray/setup-dev.py
index 48b54fa464bda..e87cf6da8ac1b 100755
--- a/python/ray/setup-dev.py
+++ b/python/ray/setup-dev.py
@@ -1,11 +1,22 @@
#!/usr/bin/env python
+# flake8: noqa E402
"""This script allows you to develop Ray Python code without needing to compile
Ray.
See https://docs.ray.io/en/master/development.html#building-ray-python-only"""
+import os
+import sys
+
+# types.py can conflict with stdlib's types.py in some python versions,
+# see https://github.com/python/cpython/issues/101210.
+# To avoid import errors, we move the current working dir to the end of sys.path.
+this_dir = os.path.dirname(__file__)
+if this_dir in sys.path:
+ cur = sys.path.remove(this_dir)
+ sys.path.append(this_dir)
+
import argparse
import click
-import os
import shutil
import subprocess
diff --git a/python/ray/tests/kuberay/utils.py b/python/ray/tests/kuberay/utils.py
index 892530912051d..345cc09bcf6d8 100644
--- a/python/ray/tests/kuberay/utils.py
+++ b/python/ray/tests/kuberay/utils.py
@@ -33,18 +33,6 @@
def setup_logging():
- """Set up logging for kuberay.
-
- For kuberay's autoscaler integration, the autoscaler runs in a sidecar container
- in the same pod as the main Ray container, which runs the rest of the Ray
- processes.
-
- The logging configuration here is for the sidecar container, but we need the
- logs to go to the same place as the head node logs because the autoscaler is
- allowed to send scaling events to Ray drivers' stdout. The implementation of
- this feature involves the autoscaler communicating to another Ray process
- (the log monitor) via logs in that directory.
- """
logging.basicConfig(
level=logging.INFO,
format=LOG_FORMAT,
diff --git a/python/ray/tests/spark/test_GPU.py b/python/ray/tests/spark/test_GPU.py
index 8de58bc2e2ec0..de338fc2717d3 100644
--- a/python/ray/tests/spark/test_GPU.py
+++ b/python/ray/tests/spark/test_GPU.py
@@ -5,10 +5,9 @@
import functools
from abc import ABC
from pyspark.sql import SparkSession
-from ray.tests.spark.test_basic import RayOnSparkCPUClusterTestBase
+from ray.tests.spark.test_basic import RayOnSparkCPUClusterTestBase, _setup_ray_cluster
import ray
-from ray.util.spark.cluster_init import _init_ray_cluster
pytestmark = pytest.mark.skipif(
not sys.platform.startswith("linux"),
@@ -22,40 +21,64 @@ class RayOnSparkGPUClusterTestBase(RayOnSparkCPUClusterTestBase, ABC):
num_gpus_per_spark_task = None
def test_gpu_allocation(self):
-
- for num_spark_tasks in [self.max_spark_tasks // 2, self.max_spark_tasks]:
- with _init_ray_cluster(num_worker_nodes=num_spark_tasks, safe_mode=False):
+ for num_worker_nodes, num_cpus_per_node, num_gpus_per_node in [
+ (
+ self.max_spark_tasks // 2,
+ self.num_cpus_per_spark_task,
+ self.num_gpus_per_spark_task,
+ ),
+ (
+ self.max_spark_tasks,
+ self.num_cpus_per_spark_task,
+ self.num_gpus_per_spark_task,
+ ),
+ (
+ self.max_spark_tasks // 2,
+ self.num_cpus_per_spark_task * 2,
+ self.num_gpus_per_spark_task * 2,
+ ),
+ (
+ self.max_spark_tasks // 2,
+ self.num_cpus_per_spark_task,
+ self.num_gpus_per_spark_task * 2,
+ ),
+ ]:
+ with _setup_ray_cluster(
+ num_worker_nodes=num_worker_nodes,
+ num_cpus_per_node=num_cpus_per_node,
+ num_gpus_per_node=num_gpus_per_node,
+ head_node_options={"include_dashboard": False},
+ ):
+ ray.init()
worker_res_list = self.get_ray_worker_resources_list()
- assert len(worker_res_list) == num_spark_tasks
+ assert len(worker_res_list) == num_worker_nodes
for worker_res in worker_res_list:
- assert worker_res["GPU"] == self.num_gpus_per_spark_task
-
- def test_basic_ray_app_using_gpu(self):
-
- with _init_ray_cluster(num_worker_nodes=self.max_spark_tasks, safe_mode=False):
+ assert worker_res["CPU"] == num_cpus_per_node
+ assert worker_res["GPU"] == num_gpus_per_node
- @ray.remote(num_cpus=1, num_gpus=1)
- def f(_):
- # Add a sleep to avoid the task finishing too fast,
- # so that it can make all ray tasks concurrently running in all idle
- # task slots.
- time.sleep(5)
- return [
- int(gpu_id)
- for gpu_id in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
- ]
+ @ray.remote(num_cpus=num_cpus_per_node, num_gpus=num_gpus_per_node)
+ def f(_):
+ # Add a sleep to avoid the task finishing too fast,
+ # so that it can make all ray tasks concurrently running in all idle
+ # task slots.
+ time.sleep(5)
+ return [
+ int(gpu_id)
+ for gpu_id in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
+ ]
- futures = [f.remote(i) for i in range(self.num_total_gpus)]
- results = ray.get(futures)
- merged_results = functools.reduce(lambda x, y: x + y, results)
- # Test all ray tasks are assigned with different GPUs.
- assert sorted(merged_results) == list(range(self.num_total_gpus))
+ futures = [f.remote(i) for i in range(num_worker_nodes)]
+ results = ray.get(futures)
+ merged_results = functools.reduce(lambda x, y: x + y, results)
+ # Test all ray tasks are assigned with different GPUs.
+ assert sorted(merged_results) == list(
+ range(num_gpus_per_node * num_worker_nodes)
+ )
class TestBasicSparkGPUCluster(RayOnSparkGPUClusterTestBase):
@classmethod
def setup_class(cls):
- super().setup_class()
cls.num_total_cpus = 2
cls.num_total_gpus = 2
cls.num_cpus_per_spark_task = 1
@@ -76,6 +99,8 @@ def setup_class(cls):
.config(
"spark.worker.resource.gpu.discoveryScript", gpu_discovery_script_path
)
+ .config("spark.executorEnv.RAY_ON_SPARK_WORKER_CPU_CORES", "2")
+ .config("spark.executorEnv.RAY_ON_SPARK_WORKER_GPU_NUM", "2")
.getOrCreate()
)
diff --git a/python/ray/tests/spark/test_basic.py b/python/ray/tests/spark/test_basic.py
index 9b3d9c6a6e13b..dac18bab16049 100644
--- a/python/ray/tests/spark/test_basic.py
+++ b/python/ray/tests/spark/test_basic.py
@@ -10,12 +10,23 @@
import ray
import ray.util.spark.cluster_init
-from ray.util.spark import init_ray_cluster, shutdown_ray_cluster
-from ray.util.spark.cluster_init import _init_ray_cluster
+from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster, MAX_NUM_WORKER_NODES
from ray.util.spark.utils import check_port_open
from pyspark.sql import SparkSession
import time
import logging
+from contextlib import contextmanager
+
+
+@contextmanager
+def _setup_ray_cluster(*args, **kwds):
+ # Code to acquire resource, e.g.:
+ setup_ray_cluster(*args, **kwds)
+ try:
+ yield ray.util.spark.cluster_init._active_ray_cluster
+ finally:
+ shutdown_ray_cluster()
+
pytestmark = pytest.mark.skipif(
not sys.platform.startswith("linux"),
@@ -32,10 +43,6 @@ class RayOnSparkCPUClusterTestBase(ABC):
num_cpus_per_spark_task = None
max_spark_tasks = None
- @classmethod
- def setup_class(cls):
- pass
-
@classmethod
def teardown_class(cls):
time.sleep(10) # Wait all background spark job canceled.
@@ -51,23 +58,41 @@ def get_ray_worker_resources_list():
return wr_list
def test_cpu_allocation(self):
- for num_spark_tasks in [self.max_spark_tasks // 2, self.max_spark_tasks]:
- with _init_ray_cluster(num_worker_nodes=num_spark_tasks, safe_mode=False):
+ for num_worker_nodes, num_cpus_per_node, num_worker_nodes_arg in [
+ (
+ self.max_spark_tasks // 2,
+ self.num_cpus_per_spark_task,
+ self.max_spark_tasks // 2,
+ ),
+ (self.max_spark_tasks, self.num_cpus_per_spark_task, MAX_NUM_WORKER_NODES),
+ (
+ self.max_spark_tasks // 2,
+ self.num_cpus_per_spark_task * 2,
+ MAX_NUM_WORKER_NODES,
+ ),
+ ]:
+ with _setup_ray_cluster(
+ num_worker_nodes=num_worker_nodes_arg,
+ num_cpus_per_node=num_cpus_per_node,
+ head_node_options={"include_dashboard": False},
+ ):
+ ray.init()
worker_res_list = self.get_ray_worker_resources_list()
- assert len(worker_res_list) == num_spark_tasks
+ assert len(worker_res_list) == num_worker_nodes
for worker_res in worker_res_list:
- assert worker_res["CPU"] == self.num_cpus_per_spark_task
+ assert worker_res["CPU"] == num_cpus_per_node
def test_public_api(self):
try:
ray_temp_root_dir = tempfile.mkdtemp()
collect_log_to_path = tempfile.mkdtemp()
- init_ray_cluster(
- num_worker_nodes=self.max_spark_tasks,
- safe_mode=False,
+ setup_ray_cluster(
+ num_worker_nodes=MAX_NUM_WORKER_NODES,
collect_log_to_path=collect_log_to_path,
ray_temp_root_dir=ray_temp_root_dir,
+ head_node_options={"include_dashboard": True},
)
+ ray.init()
@ray.remote
def f(x):
@@ -102,9 +127,8 @@ def f(x):
shutil.rmtree(collect_log_to_path, ignore_errors=True)
def test_ray_cluster_shutdown(self):
- with _init_ray_cluster(
- num_worker_nodes=self.max_spark_tasks, safe_mode=False
- ) as cluster:
+ with _setup_ray_cluster(num_worker_nodes=self.max_spark_tasks) as cluster:
+ ray.init()
assert len(self.get_ray_worker_resources_list()) == self.max_spark_tasks
# Test: cancel background spark job will cause all ray worker nodes exit.
@@ -119,9 +143,8 @@ def test_ray_cluster_shutdown(self):
assert not check_port_open(hostname, int(port))
def test_background_spark_job_exit_trigger_ray_head_exit(self):
- with _init_ray_cluster(
- num_worker_nodes=self.max_spark_tasks, safe_mode=False
- ) as cluster:
+ with _setup_ray_cluster(num_worker_nodes=self.max_spark_tasks) as cluster:
+ ray.init()
# Mimic the case the job failed unexpectedly.
cluster._cancel_background_spark_job()
cluster.spark_job_is_canceled = False
@@ -135,7 +158,6 @@ def test_background_spark_job_exit_trigger_ray_head_exit(self):
class TestBasicSparkCluster(RayOnSparkCPUClusterTestBase):
@classmethod
def setup_class(cls):
- super().setup_class()
cls.num_total_cpus = 2
cls.num_total_gpus = 0
cls.num_cpus_per_spark_task = 1
@@ -146,6 +168,7 @@ def setup_class(cls):
SparkSession.builder.master("local-cluster[1, 2, 1024]")
.config("spark.task.cpus", "1")
.config("spark.task.maxFailures", "1")
+ .config("spark.executorEnv.RAY_ON_SPARK_WORKER_CPU_CORES", "2")
.getOrCreate()
)
diff --git a/python/ray/tests/spark/test_multicores_per_task.py b/python/ray/tests/spark/test_multicores_per_task.py
index 95bdee432c893..5f9fa3b805313 100644
--- a/python/ray/tests/spark/test_multicores_per_task.py
+++ b/python/ray/tests/spark/test_multicores_per_task.py
@@ -13,7 +13,6 @@
class TestMultiCoresPerTaskCluster(RayOnSparkGPUClusterTestBase):
@classmethod
def setup_class(cls):
- super().setup_class()
cls.num_total_cpus = 4
cls.num_total_gpus = 4
cls.num_cpus_per_spark_task = 2
@@ -34,6 +33,8 @@ def setup_class(cls):
.config(
"spark.worker.resource.gpu.discoveryScript", gpu_discovery_script_path
)
+ .config("spark.executorEnv.RAY_ON_SPARK_WORKER_CPU_CORES", "4")
+ .config("spark.executorEnv.RAY_ON_SPARK_WORKER_GPU_NUM", "4")
.getOrCreate()
)
diff --git a/python/ray/tests/spark/test_utils.py b/python/ray/tests/spark/test_utils.py
index fa7fe1b93942f..d9c7e570483f6 100644
--- a/python/ray/tests/spark/test_utils.py
+++ b/python/ray/tests/spark/test_utils.py
@@ -6,6 +6,7 @@
from ray.util.spark.utils import (
get_spark_task_assigned_physical_gpus,
_calc_mem_per_ray_worker_node,
+ _get_avail_mem_per_ray_worker_node,
)
pytestmark = pytest.mark.skipif(
@@ -23,10 +24,66 @@ def test_get_spark_task_assigned_physical_gpus():
assert get_spark_task_assigned_physical_gpus([0, 2]) == [2, 6]
+@patch("ray._private.ray_constants.OBJECT_STORE_MINIMUM_MEMORY_BYTES", 1)
def test_calc_mem_per_ray_worker_node():
- assert _calc_mem_per_ray_worker_node(4, 1000000, 400000, 100000) == (120000, 80000)
- assert _calc_mem_per_ray_worker_node(4, 1000000, 400000, 70000) == (130000, 70000)
- assert _calc_mem_per_ray_worker_node(4, 1000000, 400000, None) == (120000, 80000)
+ assert _calc_mem_per_ray_worker_node(4, 1000000, 400000, 100000) == (
+ 120000,
+ 80000,
+ None,
+ )
+ assert _calc_mem_per_ray_worker_node(4, 1000000, 400000, 70000) == (
+ 130000,
+ 70000,
+ None,
+ )
+ assert _calc_mem_per_ray_worker_node(4, 1000000, 400000, None) == (
+ 140000,
+ 60000,
+ None,
+ )
+ assert _calc_mem_per_ray_worker_node(4, 1000000, 200000, None) == (
+ 160000,
+ 40000,
+ None,
+ )
+
+
+@patch("ray._private.ray_constants.OBJECT_STORE_MINIMUM_MEMORY_BYTES", 1)
+def test_get_avail_mem_per_ray_worker_node(monkeypatch):
+ monkeypatch.setenv("RAY_ON_SPARK_WORKER_CPU_CORES", "4")
+ monkeypatch.setenv("RAY_ON_SPARK_WORKER_GPU_NUM", "8")
+ monkeypatch.setenv("RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES", "1000000")
+ monkeypatch.setenv("RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES", "500000")
+
+ assert _get_avail_mem_per_ray_worker_node(
+ num_cpus_per_node=1,
+ num_gpus_per_node=2,
+ object_store_memory_per_node=None,
+ ) == (140000, 60000, None, None)
+
+ assert _get_avail_mem_per_ray_worker_node(
+ num_cpus_per_node=1,
+ num_gpus_per_node=2,
+ object_store_memory_per_node=80000,
+ ) == (120000, 80000, None, None)
+
+ assert _get_avail_mem_per_ray_worker_node(
+ num_cpus_per_node=1,
+ num_gpus_per_node=2,
+ object_store_memory_per_node=120000,
+ ) == (100000, 100000, None, None)
+
+ assert _get_avail_mem_per_ray_worker_node(
+ num_cpus_per_node=2,
+ num_gpus_per_node=2,
+ object_store_memory_per_node=None,
+ ) == (280000, 120000, None, None)
+
+ assert _get_avail_mem_per_ray_worker_node(
+ num_cpus_per_node=1,
+ num_gpus_per_node=4,
+ object_store_memory_per_node=None,
+ ) == (280000, 120000, None, None)
if __name__ == "__main__":
diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py
index 104714334865c..447d8cbabbddd 100644
--- a/python/ray/tests/test_advanced.py
+++ b/python/ray/tests/test_advanced.py
@@ -193,7 +193,7 @@ def f(worker_info):
ray._private.worker.global_worker.run_function_on_all_workers(f)
-@pytest.mark.skip(reason="Flaky tests")
+@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_running_function_on_all_workers(ray_start_regular):
def f(worker_info):
sys.path.append("fake_directory")
diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py
index 4a639b1b69cc8..3a8835c99488a 100644
--- a/python/ray/tests/test_autoscaler.py
+++ b/python/ray/tests/test_autoscaler.py
@@ -1,4 +1,5 @@
import copy
+import logging
import sys
import json
import os
@@ -3801,6 +3802,41 @@ def testInitializeSDKArguments(self):
with self.assertRaises(TypeError):
request_resources(bundles=[{"foo": 1}, {"bar": "baz"}])
+ def test_autoscaler_status_log(self):
+ self._test_autoscaler_status_log(status_log_enabled_env=1)
+ self._test_autoscaler_status_log(status_log_enabled_env=0)
+
+ def _test_autoscaler_status_log(self, status_log_enabled_env: int):
+ mock_logger = Mock(spec=logging.Logger(""))
+ with patch.multiple(
+ "ray.autoscaler._private.autoscaler",
+ logger=mock_logger,
+ AUTOSCALER_STATUS_LOG=status_log_enabled_env,
+ ):
+ config = copy.deepcopy(SMALL_CLUSTER)
+ config_path = self.write_config(config)
+ runner = MockProcessRunner()
+ mock_metrics = Mock(spec=AutoscalerPrometheusMetrics())
+ self.provider = MockProvider()
+ autoscaler = MockAutoscaler(
+ config_path,
+ LoadMetrics(),
+ MockNodeInfoStub(),
+ max_failures=0,
+ process_runner=runner,
+ update_interval_s=0,
+ prom_metrics=mock_metrics,
+ )
+ autoscaler.update()
+ status_log_found = False
+ for call in mock_logger.info.call_args_list:
+ args, _ = call
+ arg = args[0]
+ if " Autoscaler status: " in arg:
+ status_log_found = True
+ break
+ assert status_log_found is bool(status_log_enabled_env)
+
def test_import():
"""This test ensures that all the autoscaler imports work as expected to
diff --git a/python/ray/tests/test_basic_5.py b/python/ray/tests/test_basic_5.py
index 083df25b53c91..5571bc5277448 100644
--- a/python/ray/tests/test_basic_5.py
+++ b/python/ray/tests/test_basic_5.py
@@ -142,7 +142,7 @@ def pid(self):
assert "Traceback" not in log
-@pytest.mark.skipif(True, reason="run_function_on_all_workers doesn't work")
+@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on windows")
def test_run_on_all_workers(call_ray_start, tmp_path):
# This test is to ensure run_function_on_all_workers are executed
# on all workers.
@@ -247,7 +247,7 @@ def get_kv_metrics():
???? # unknown
"""
# !!!If you want to increase this number, please let ray-core knows this!!!
- assert freqs["internal_kv_get"] == 5
+ assert freqs["internal_kv_get"] == 4
if __name__ == "__main__":
diff --git a/python/ray/tests/test_exit_observability.py b/python/ray/tests/test_exit_observability.py
index d1f1b258b2583..ebb52cd97e345 100644
--- a/python/ray/tests/test_exit_observability.py
+++ b/python/ray/tests/test_exit_observability.py
@@ -6,7 +6,7 @@
import ray
from ray._private.test_utils import run_string_as_driver, wait_for_condition
-from ray.experimental.state.api import list_workers
+from ray.experimental.state.api import list_workers, list_nodes
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -314,6 +314,82 @@ def verify_exit_by_actor_init_failure():
wait_for_condition(verify_exit_by_actor_init_failure)
+@pytest.mark.skipif(
+ sys.platform == "win32",
+ reason="Failed on Windows because sigkill doesn't work on Windows",
+)
+def test_worker_start_end_time(shutdown_only):
+ ray.init(num_cpus=1)
+
+ @ray.remote
+ class Worker:
+ def ready(self):
+ return os.getpid()
+
+ # Test normal exit.
+ worker = Worker.remote()
+ pid = ray.get(worker.ready.remote())
+ workers = list_workers(detail=True, filters=[("pid", "=", pid)])[0]
+ print(workers)
+ assert workers["start_time_ms"] > 0
+ assert workers["end_time_ms"] == 0
+
+ ray.kill(worker)
+ workers = list_workers(detail=True, filters=[("pid", "=", pid)])[0]
+ assert workers["start_time_ms"] > 0
+ assert workers["end_time_ms"] > 0
+
+ # Test unexpected exit.
+ worker = Worker.remote()
+ pid = ray.get(worker.ready.remote())
+ os.kill(pid, signal.SIGKILL)
+ workers = list_workers(detail=True, filters=[("pid", "=", pid)])[0]
+ assert workers["start_time_ms"] > 0
+ assert workers["end_time_ms"] > 0
+
+
+def test_node_start_end_time(ray_start_cluster):
+ cluster = ray_start_cluster
+ # head
+ cluster.add_node(num_cpus=0)
+ nodes = list_nodes(detail=True)
+ head_node_id = nodes[0]["node_id"]
+
+ worker_node = cluster.add_node(num_cpus=0)
+ nodes = list_nodes(detail=True)
+ worker_node_data = list(
+ filter(lambda x: x["node_id"] != head_node_id and x["state"] == "ALIVE", nodes)
+ )[0]
+ assert worker_node_data["start_time_ms"] > 0
+ assert worker_node_data["end_time_ms"] == 0
+
+ # Test expected exit.
+ cluster.remove_node(worker_node, allow_graceful=True)
+ nodes = list_nodes(detail=True)
+ worker_node_data = list(
+ filter(lambda x: x["node_id"] != head_node_id and x["state"] == "DEAD", nodes)
+ )[0]
+ assert worker_node_data["start_time_ms"] > 0
+ assert worker_node_data["end_time_ms"] > 0
+
+ # Test unexpected exit.
+ worker_node = cluster.add_node(num_cpus=0)
+ nodes = list_nodes(detail=True)
+ worker_node_data = list(
+ filter(lambda x: x["node_id"] != head_node_id and x["state"] == "ALIVE", nodes)
+ )[0]
+ assert worker_node_data["start_time_ms"] > 0
+ assert worker_node_data["end_time_ms"] == 0
+
+ cluster.remove_node(worker_node, allow_graceful=False)
+ nodes = list_nodes(detail=True)
+ worker_node_data = list(
+ filter(lambda x: x["node_id"] != head_node_id and x["state"] == "DEAD", nodes)
+ )[0]
+ assert worker_node_data["start_time_ms"] > 0
+ assert worker_node_data["end_time_ms"] > 0
+
+
if __name__ == "__main__":
import pytest
diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py
index 6abc9d9592f61..71bb7a98dd9ad 100644
--- a/python/ray/tests/test_failure.py
+++ b/python/ray/tests/test_failure.py
@@ -125,7 +125,6 @@ def expect_exception(objects, exception):
ray.get(signal2.send.remote())
-@pytest.mark.skipif(True, reason="run_function_on_all_workers doesn't work")
def test_failed_function_to_run(ray_start_2_cpus, error_pubsub):
p = error_pubsub
@@ -134,6 +133,20 @@ def f(worker):
raise Exception("Function to run failed.")
ray._private.worker.global_worker.run_function_on_all_workers(f)
+
+ @ray.remote
+ class Actor:
+ def foo(self):
+ pass
+
+ # Functions scheduled through run_function_on_all_workers only
+ # executes on workers binded with current driver's job_id.
+ # Since the 2 prestarted workers lazily bind to job_id until the first
+ # task/actor executed, we need to schedule two actors to trigger
+ # prestart functions.
+ actors = [Actor.remote() for _ in range(2)]
+ ray.get([actor.foo.remote() for actor in actors])
+
# Check that the error message is in the task info.
errors = get_error_message(p, 2, ray_constants.FUNCTION_TO_RUN_PUSH_ERROR)
assert len(errors) == 2
diff --git a/python/ray/tests/test_gcs_utils.py b/python/ray/tests/test_gcs_utils.py
index a2f9a5037b6ad..cbc3b2d757fde 100644
--- a/python/ray/tests/test_gcs_utils.py
+++ b/python/ray/tests/test_gcs_utils.py
@@ -7,6 +7,7 @@
import grpc
import pytest
import ray
+import redis
from ray._private.gcs_utils import GcsClient
import ray._private.gcs_utils as gcs_utils
from ray._private.test_utils import (
@@ -211,6 +212,44 @@ async def check(expect_liveness):
)
+@pytest.fixture(params=[True, False])
+def redis_replicas(request, monkeypatch):
+ if request.param:
+ monkeypatch.setenv("TEST_EXTERNAL_REDIS_REPLICAS", "3")
+ yield
+
+
+@pytest.mark.skipif(
+ not enable_external_redis(), reason="Only valid when start with an external redis"
+)
+def test_redis_cleanup(redis_replicas, shutdown_only):
+ addr = ray.init(
+ namespace="a", _system_config={"external_storage_namespace": "c1"}
+ ).address_info["address"]
+ gcs_client = GcsClient(address=addr)
+ gcs_client.internal_kv_put(b"ABC", b"DEF", True, None)
+
+ ray.shutdown()
+ addr = ray.init(
+ namespace="a", _system_config={"external_storage_namespace": "c2"}
+ ).address_info["address"]
+ gcs_client = GcsClient(address=addr)
+ gcs_client.internal_kv_put(b"ABC", b"XYZ", True, None)
+ ray.shutdown()
+ redis_addr = os.environ["RAY_REDIS_ADDRESS"]
+ host, port = redis_addr.split(":")
+ if os.environ.get("TEST_EXTERNAL_REDIS_REPLICAS", "1") != "1":
+ cli = redis.RedisCluster(host, int(port))
+ else:
+ cli = redis.Redis(host, int(port))
+
+ assert set(cli.keys()) == {b"c1", b"c2"}
+ gcs_utils.cleanup_redis_storage(host, int(port), "", False, "c1")
+ assert set(cli.keys()) == {b"c2"}
+ gcs_utils.cleanup_redis_storage(host, int(port), "", False, "c2")
+ assert len(cli.keys()) == 0
+
+
if __name__ == "__main__":
import sys
diff --git a/python/ray/tests/test_multiprocessing.py b/python/ray/tests/test_multiprocessing.py
index 7ca5f9401ffc2..07051b1ef3659 100644
--- a/python/ray/tests/test_multiprocessing.py
+++ b/python/ray/tests/test_multiprocessing.py
@@ -6,6 +6,7 @@
import time
import random
from collections import defaultdict
+import warnings
import queue
import math
@@ -508,6 +509,28 @@ def f(args):
result_iter.next()
+@pytest.mark.filterwarnings(
+ "default:Passing a non-iterable argument:ray.util.annotations.RayDeprecationWarning"
+)
+def test_warn_on_non_iterable_imap_or_imap_unordered(pool):
+ def fn(_):
+ pass
+
+ non_iterable = 3
+
+ with warnings.catch_warnings(record=True) as w:
+ pool.imap(fn, non_iterable)
+ assert any(
+ "Passing a non-iterable argument" in str(warning.message) for warning in w
+ )
+
+ with warnings.catch_warnings(record=True) as w:
+ pool.imap_unordered(fn, non_iterable)
+ assert any(
+ "Passing a non-iterable argument" in str(warning.message) for warning in w
+ )
+
+
@pytest.mark.parametrize("use_iter", [True, False])
def test_imap_unordered(pool_4_processes, use_iter):
def f(args):
diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py
index f30791b1bc139..1be1619a7ecc7 100644
--- a/python/ray/tests/test_state_api.py
+++ b/python/ray/tests/test_state_api.py
@@ -1,3 +1,4 @@
+import os
import time
import json
import sys
@@ -12,6 +13,7 @@
import yaml
from click.testing import CliRunner
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
import ray
import ray.dashboard.consts as dashboard_consts
import ray._private.state as global_state
@@ -801,12 +803,13 @@ async def test_api_manager_list_tasks(state_api_manager):
)
]
result = await state_api_manager.list_tasks(option=create_api_options())
- data_source_client.get_all_task_info.assert_any_await(timeout=DEFAULT_RPC_TIMEOUT)
+ data_source_client.get_all_task_info.assert_any_await(
+ timeout=DEFAULT_RPC_TIMEOUT, job_id=None
+ )
data = result.result
data = data
assert len(data) == 2
assert result.total == 2
-
verify_schema(TaskState, data[0])
assert data[0]["node_id"] == node_id.hex()
verify_schema(TaskState, data[1])
@@ -862,6 +865,118 @@ async def test_api_manager_list_tasks(state_api_manager):
assert len(result.result) == 1
+@pytest.mark.skipif(
+ sys.version_info < (3, 8, 0),
+ reason=("Not passing in CI although it works locally. Will handle it later."),
+)
+@pytest.mark.asyncio
+async def test_api_manager_list_tasks_events(state_api_manager):
+ data_source_client = state_api_manager.data_source_client
+
+ node_id = NodeID.from_random()
+ data_source_client.get_all_task_info = AsyncMock()
+ id = b"1234"
+ func_or_class = "f"
+
+ # Generate a task event.
+
+ task_info = TaskInfoEntry(
+ task_id=id,
+ name=func_or_class,
+ func_or_class_name=func_or_class,
+ type=TaskType.NORMAL_TASK,
+ )
+ current = time.time_ns()
+ second = int(1e9)
+ state_updates = TaskStateUpdate(
+ node_id=node_id.binary(),
+ pending_args_avail_ts=current,
+ submitted_to_worker_ts=current + second,
+ running_ts=current + (2 * second),
+ finished_ts=current + (3 * second),
+ )
+
+ """
+ Test basic.
+ """
+ events = TaskEvents(
+ task_id=id,
+ job_id=b"0001",
+ attempt_number=0,
+ task_info=task_info,
+ state_updates=state_updates,
+ )
+ data_source_client.get_all_task_info.side_effect = [generate_task_data([events])]
+ result = await state_api_manager.list_tasks(option=create_api_options(detail=True))
+ result = result.result[0]
+ assert "events" in result
+ assert result["state"] == "FINISHED"
+ expected_events = [
+ {
+ "state": "PENDING_ARGS_AVAIL",
+ "created_ms": current // 1e6,
+ },
+ {
+ "state": "SUBMITTED_TO_WORKER",
+ "created_ms": (current + second) // 1e6,
+ },
+ {
+ "state": "RUNNING",
+ "created_ms": (current + 2 * second) // 1e6,
+ },
+ {
+ "state": "FINISHED",
+ "created_ms": (current + 3 * second) // 1e6,
+ },
+ ]
+ for actual, expected in zip(result["events"], expected_events):
+ assert actual == expected
+ assert result["start_time_ms"] == (current + 2 * second) // 1e6
+ assert result["end_time_ms"] == (current + 3 * second) // 1e6
+
+ """
+ Test only start_time_ms is updated.
+ """
+ state_updates = TaskStateUpdate(
+ node_id=node_id.binary(),
+ pending_args_avail_ts=current,
+ submitted_to_worker_ts=current + second,
+ running_ts=current + (2 * second),
+ )
+ events = TaskEvents(
+ task_id=id,
+ job_id=b"0001",
+ attempt_number=0,
+ task_info=task_info,
+ state_updates=state_updates,
+ )
+ data_source_client.get_all_task_info.side_effect = [generate_task_data([events])]
+ result = await state_api_manager.list_tasks(option=create_api_options(detail=True))
+ result = result.result[0]
+ assert result["start_time_ms"] == (current + 2 * second) // 1e6
+ assert result["end_time_ms"] is None
+
+ """
+ Test None of start & end time is updated.
+ """
+ state_updates = TaskStateUpdate(
+ pending_args_avail_ts=current,
+ submitted_to_worker_ts=current + second,
+ )
+ events = TaskEvents(
+ task_id=id,
+ job_id=b"0001",
+ attempt_number=0,
+ task_info=task_info,
+ state_updates=state_updates,
+ )
+ data_source_client.get_all_task_info.side_effect = [generate_task_data([events])]
+ result = await state_api_manager.list_tasks(option=create_api_options(detail=True))
+ result = result.result[0]
+ assert result["start_time_ms"] is None
+ assert result["end_time_ms"] is None
+
+
@pytest.mark.skipif(
sys.version_info < (3, 8, 0),
reason=("Not passing in CI although it works locally. Will handle it later."),
@@ -2039,7 +2154,7 @@ def verify():
waiting_for_execution = len(
list(
filter(
- lambda task: task["scheduling_state"] == "SUBMITTED_TO_WORKER",
+ lambda task: task["state"] == "SUBMITTED_TO_WORKER",
tasks,
)
)
@@ -2048,7 +2163,7 @@ def verify():
scheduled = len(
list(
filter(
- lambda task: task["scheduling_state"] == "PENDING_NODE_ASSIGNMENT",
+ lambda task: task["state"] == "PENDING_NODE_ASSIGNMENT",
tasks,
)
)
@@ -2057,7 +2172,7 @@ def verify():
waiting_for_dep = len(
list(
filter(
- lambda task: task["scheduling_state"] == "PENDING_ARGS_AVAIL",
+ lambda task: task["state"] == "PENDING_ARGS_AVAIL",
tasks,
)
)
@@ -2066,7 +2181,7 @@ def verify():
running = len(
list(
filter(
- lambda task: task["scheduling_state"] == "RUNNING",
+ lambda task: task["state"] == "RUNNING",
tasks,
)
)
@@ -2080,22 +2195,76 @@ def verify():
assert get_task_data == task
# Test node id.
- tasks = list_tasks(
- filters=[("scheduling_state", "=", "PENDING_NODE_ASSIGNMENT")]
- )
+ tasks = list_tasks(filters=[("state", "=", "PENDING_NODE_ASSIGNMENT")])
for task in tasks:
assert task["node_id"] is None
- tasks = list_tasks(filters=[("scheduling_state", "=", "RUNNING")])
+ tasks = list_tasks(filters=[("state", "=", "RUNNING")])
for task in tasks:
assert task["node_id"] == node_id
+ tasks = list_tasks(filters=[("job_id", "=", job_id)])
+ for task in tasks:
+ assert task["job_id"] == job_id
+
return True
wait_for_condition(verify)
print(list_tasks())
+def test_pg_worker_id_tasks(shutdown_only):
+ ray.init(num_cpus=1)
+ pg = ray.util.placement_group(bundles=[{"CPU": 1}])
+ pg.wait()
+
+ @ray.remote
+ def f():
+ pass
+
+ @ray.remote
+ class A:
+ def ready(self):
+ return os.getpid()
+
+ ray.get(
+ f.options(
+ scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg)
+ ).remote()
+ )
+
+ def verify():
+ tasks = list_tasks(detail=True)
+ workers = list_workers(filters=[("worker_type", "=", "WORKER")])
+ assert len(tasks) == 1
+ assert len(workers) == 1
+
+ assert tasks[0]["placement_group_id"] == pg.id.hex()
+ assert tasks[0]["worker_id"] == workers[0]["worker_id"]
+
+ return True
+
+ wait_for_condition(verify)
+ print(list_tasks(detail=True))
+
+ a = A.options(
+ scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg)
+ ).remote()
+ pid = ray.get(a.ready.remote())
+
+ def verify():
+ actors = list_actors(detail=True)
+ workers = list_workers(detail=True, filters=[("pid", "=", pid)])
+ assert len(actors) == 1
+ assert len(workers) == 1
+
+ assert actors[0]["placement_group_id"] == pg.id.hex()
+ return True
+
+ wait_for_condition(verify)
+ print(list_actors(detail=True))
+
+
def test_parent_task_id(shutdown_only):
"""Test parent task id set up properly"""
ray.init(num_cpus=2)
@@ -2111,7 +2280,7 @@ def parent():
ray.get(parent.remote())
def verify():
- tasks = list_tasks()
+ tasks = list_tasks(detail=True)
assert len(tasks) == 2, "Expect 2 tasks to finished"
parent_task_id = None
child_parent_task_id = None
@@ -2145,7 +2314,7 @@ def verify(task_attempts):
assert len(task_attempts) == 3 # 2 retries + 1 initial run
for task_attempt in task_attempts:
assert task_attempt["job_id"] == job_id
- assert task_attempt["scheduling_state"] == "FAILED"
+ assert task_attempt["state"] == "FAILED"
assert task_attempt["node_id"] == node_id
assert {task_attempt["attempt_number"] for task_attempt in task_attempts} == {
@@ -2193,9 +2362,9 @@ def f():
def verify(task_attempts):
assert len(task_attempts) == 3
for task_attempt in task_attempts[:-1]:
- assert task_attempt["scheduling_state"] == "FAILED"
+ assert task_attempt["state"] == "FAILED"
- task_attempts[-1]["scheduling_state"] == "FINISHED"
+ task_attempts[-1]["state"] == "FINISHED"
assert {task_attempt["attempt_number"] for task_attempt in task_attempts} == {
0,
@@ -2237,7 +2406,7 @@ def verify():
len(
list(
filter(
- lambda task: task["scheduling_state"] == "SUBMITTED_TO_WORKER",
+ lambda task: task["state"] == "SUBMITTED_TO_WORKER",
tasks,
)
)
@@ -2248,8 +2417,7 @@ def verify():
len(
list(
filter(
- lambda task: task["scheduling_state"]
- == "PENDING_NODE_ASSIGNMENT",
+ lambda task: task["state"] == "PENDING_NODE_ASSIGNMENT",
tasks,
)
)
@@ -2260,7 +2428,7 @@ def verify():
len(
list(
filter(
- lambda task: task["scheduling_state"] == "PENDING_ARGS_AVAIL",
+ lambda task: task["state"] == "PENDING_ARGS_AVAIL",
tasks,
)
)
@@ -2271,7 +2439,7 @@ def verify():
len(
list(
filter(
- lambda task: task["scheduling_state"] == "RUNNING",
+ lambda task: task["state"] == "RUNNING",
tasks,
)
)
@@ -2370,6 +2538,10 @@ def ready(self):
assert output == list_actors(limit=2)
+@pytest.mark.skipif(
+ sys.platform == "win32",
+ reason="Failed on Windows",
+)
def test_network_failure(shutdown_only):
"""When the request fails due to network failure,
verifies it raises an exception."""
@@ -3134,7 +3306,6 @@ def test_core_state_api_usage_tags(shutdown_only):
if __name__ == "__main__":
- import os
import sys
if os.environ.get("PARALLEL_CI"):
diff --git a/python/ray/tests/test_task_events.py b/python/ray/tests/test_task_events.py
index 28e49933b6aab..87d16772a363b 100644
--- a/python/ray/tests/test_task_events.py
+++ b/python/ray/tests/test_task_events.py
@@ -1,9 +1,12 @@
from collections import defaultdict
from typing import Dict
+import pytest
+import time
import ray
from ray._private.test_utils import (
raw_metrics,
+ run_string_as_driver_nonblocking,
wait_for_condition,
)
from ray.experimental.state.api import list_tasks
@@ -14,6 +17,7 @@
"task_events_report_interval_ms": 100,
"metrics_report_interval_ms": 200,
"enable_timeline": False,
+ "gcs_mark_task_failed_on_job_done_delay_ms": 1000,
}
@@ -77,7 +81,6 @@ def verify():
def test_fault_tolerance_parent_failed(shutdown_only):
ray.init(num_cpus=4, _system_config=_SYSTEM_CONFIG)
- import time
# Each parent task spins off 2 child task, where each child spins off
# 1 grand_child task.
@@ -109,7 +112,7 @@ def verify():
)
print(tasks)
for task in tasks:
- assert task["scheduling_state"] == "FAILED"
+ assert task["state"] == "FAILED"
return True
@@ -118,3 +121,298 @@ def verify():
timeout=10,
retry_interval_ms=500,
)
+
+
+def test_fault_tolerance_job_failed(shutdown_only):
+ ray.init(num_cpus=8, _system_config=_SYSTEM_CONFIG)
+ script = """
+import ray
+import time
+
+ray.init("auto")
+NUM_CHILD = 2
+
+@ray.remote
+def grandchild():
+ time.sleep(999)
+
+@ray.remote
+def child():
+ ray.get(grandchild.remote())
+
+@ray.remote
+def finished_child():
+ ray.put(1)
+ return
+
+@ray.remote
+def parent():
+ children = [child.remote() for _ in range(NUM_CHILD)]
+ finished_children = ray.get([finished_child.remote() for _ in range(NUM_CHILD)])
+ ray.get(children)
+
+ray.get(parent.remote())
+
+"""
+ proc = run_string_as_driver_nonblocking(script)
+
+ def verify():
+ tasks = list_tasks()
+ print(tasks)
+ assert len(tasks) == 7, (
+ "Incorrect number of tasks are reported. "
+ "Expected length: 1 parent + 2 finished child + 2 failed child + "
+ "2 failed grandchild tasks"
+ )
+ return True
+
+ wait_for_condition(
+ verify,
+ timeout=10,
+ retry_interval_ms=500,
+ )
+
+ proc.kill()
+
+ def verify():
+ tasks = list_tasks()
+ assert len(tasks) == 7, (
+ "Incorrect number of tasks are reported. "
+ "Expected length: 1 parent + 2 finished child + 2 failed child + "
+ "2 failed grandchild tasks"
+ )
+ for task in tasks:
+ if "finished" in task["func_or_class_name"]:
+ assert (
+ task["scheduling_state"] == "FINISHED"
+ ), f"task {task['func_or_class_name']} has wrong state"
+ else:
+ assert (
+ task["scheduling_state"] == "FAILED"
+ ), f"task {task['func_or_class_name']} has wrong state"
+
+ return True
+
+ wait_for_condition(
+ verify,
+ timeout=10,
+ retry_interval_ms=500,
+ )
+
+
+@ray.remote
+def task_finish_child():
+ pass
+
+
+@ray.remote
+def task_sleep_child():
+ time.sleep(999)
+
+
+@ray.remote
+class ChildActor:
+ def children(self):
+ ray.get(task_finish_child.remote())
+ ray.get(task_sleep_child.remote())
+
+
+@ray.remote
+class Actor:
+ def fail_parent(self):
+ task_finish_child.remote()
+ task_sleep_child.remote()
+ raise ValueError("expected to fail.")
+
+ def child_actor(self):
+ a = ChildActor.remote()
+ try:
+ ray.get(a.children.remote(), timeout=2)
+ except ray.exceptions.GetTimeoutError:
+ pass
+ raise ValueError("expected to fail.")
+
+
+def test_fault_tolerance_actor_tasks_failed(shutdown_only):
+ ray.init(_system_config=_SYSTEM_CONFIG)
+ # Test actor tasks
+ with pytest.raises(ray.exceptions.RayTaskError):
+ a = Actor.remote()
+ ray.get(a.fail_parent.remote())
+
+ def verify():
+ tasks = list_tasks()
+ assert (
+ len(tasks) == 4
+ ), "1 creation task + 1 actor tasks + 2 normal tasks run by the actor tasks"
+ for task in tasks:
+ if "finish" in task["name"] or "__init__" in task["name"]:
+ assert task["scheduling_state"] == "FINISHED", task
+ else:
+ assert task["scheduling_state"] == "FAILED", task
+
+ return True
+
+ wait_for_condition(
+ verify,
+ timeout=10,
+ retry_interval_ms=500,
+ )
+
+
+def test_fault_tolerance_nested_actors_failed(shutdown_only):
+ ray.init(_system_config=_SYSTEM_CONFIG)
+
+ # Test nested actor tasks
+ with pytest.raises(ray.exceptions.RayTaskError):
+ a = Actor.remote()
+ ray.get(a.child_actor.remote())
+
+ def verify():
+ tasks = list_tasks()
+ assert len(tasks) == 6, (
+ "2 creation task + 1 parent actor task + 1 child actor task "
+ " + 2 normal tasks run by child actor"
+ )
+ for task in tasks:
+ if "finish" in task["name"] or "__init__" in task["name"]:
+ assert task["scheduling_state"] == "FINISHED", task
+ else:
+ assert task["scheduling_state"] == "FAILED", task
+
+ return True
+
+ wait_for_condition(
+ verify,
+ timeout=10,
+ retry_interval_ms=500,
+ )
+
+
+@pytest.mark.parametrize("death_list", [["A"], ["Abb", "C"], ["Abb", "Ca", "A"]])
+def test_fault_tolerance_advanced_tree(shutdown_only, death_list):
+ import asyncio
+
+ # Some constants
+ NORMAL_TASK = 0
+ ACTOR_TASK = 1
+
+ # Root should always be finish
+ execution_graph = {
+ "root": [
+ (NORMAL_TASK, "A"),
+ (ACTOR_TASK, "B"),
+ (NORMAL_TASK, "C"),
+ (ACTOR_TASK, "D"),
+ ],
+ "A": [(ACTOR_TASK, "Aa"), (NORMAL_TASK, "Ab")],
+ "C": [(ACTOR_TASK, "Ca"), (NORMAL_TASK, "Cb")],
+ "D": [
+ (NORMAL_TASK, "Da"),
+ (NORMAL_TASK, "Db"),
+ (ACTOR_TASK, "Dc"),
+ (ACTOR_TASK, "Dd"),
+ ],
+ "Aa": [],
+ "Ab": [(ACTOR_TASK, "Aba"), (NORMAL_TASK, "Abb"), (NORMAL_TASK, "Abc")],
+ "Ca": [(ACTOR_TASK, "Caa"), (NORMAL_TASK, "Cab")],
+ "Abb": [(NORMAL_TASK, "Abba")],
+ "Abc": [],
+ "Abba": [(NORMAL_TASK, "Abbaa"), (ACTOR_TASK, "Abbab")],
+ "Abbaa": [(NORMAL_TASK, "Abbaaa"), (ACTOR_TASK, "Abbaab")],
+ }
+
+ ray.init(_system_config=_SYSTEM_CONFIG)
+
+ @ray.remote
+ class Killer:
+ def __init__(self, death_list, wait_time):
+ self.idx_ = 0
+ self.death_list_ = death_list
+ self.wait_time_ = wait_time
+ self.start_ = time.time()
+
+ async def next_to_kill(self):
+ now = time.time()
+ if now - self.start_ < self.wait_time_:
+ # Sleep until killing starts...
+ time.sleep(self.wait_time_ - (now - self.start_))
+
+ # if no more tasks to kill - simply sleep to keep all running tasks blocked.
+ while self.idx_ >= len(self.death_list_):
+ await asyncio.sleep(999)
+
+ to_kill = self.death_list_[self.idx_]
+ print(f"{to_kill} to be killed")
+ return to_kill
+
+ async def advance_next(self):
+ self.idx_ += 1
+
+ def run_children(my_name, killer, execution_graph):
+ children = execution_graph.get(my_name, [])
+ for task_type, child_name in children:
+ if task_type == NORMAL_TASK:
+ task.options(name=child_name).remote(
+ child_name, killer, execution_graph
+ )
+ else:
+ a = Actor.remote()
+ a.actor_task.options(name=child_name).remote(
+ child_name, killer, execution_graph
+ )
+
+ # Block until killed
+ while True:
+ to_fail = ray.get(killer.next_to_kill.remote())
+ if to_fail == my_name:
+ ray.get(killer.advance_next.remote())
+ raise ValueError(f"{my_name} expected to fail")
+
+ @ray.remote
+ class Actor:
+ def actor_task(self, my_name, killer, execution_graph):
+ run_children(my_name, killer, execution_graph)
+
+ @ray.remote
+ def task(my_name, killer, execution_graph):
+ run_children(my_name, killer, execution_graph)
+
+ killer = Killer.remote(death_list, 5)
+
+ task.options(name="root").remote("root", killer, execution_graph)
+
+ def verify():
+ tasks = list_tasks()
+ target_tasks = filter(
+ lambda task: "__init__" not in task["name"]
+ and "Killer" not in task["name"],
+ tasks,
+ )
+
+ # Calculate tasks that should have failed
+ dead_tasks = set()
+
+ def add_death_tasks_recur(task, execution_graph, dead_tasks):
+ children = execution_graph.get(task, [])
+ dead_tasks.add(task)
+
+ for _, child in children:
+ add_death_tasks_recur(child, execution_graph, dead_tasks)
+
+ for task in death_list:
+ add_death_tasks_recur(task, execution_graph, dead_tasks)
+
+ for task in target_tasks:
+ if task["name"] in dead_tasks:
+ assert task["scheduling_state"] == "FAILED", task["name"]
+ else:
+ assert task["scheduling_state"] == "RUNNING", task["name"]
+
+ return True
+
+ wait_for_condition(
+ verify,
+ timeout=15,
+ retry_interval_ms=500,
+ )
diff --git a/python/ray/train/BUILD b/python/ray/train/BUILD
index 8fe3a866bebd7..6c8a9026ee949 100644
--- a/python/ray/train/BUILD
+++ b/python/ray/train/BUILD
@@ -536,6 +536,14 @@ py_test(
deps = [":train_lib"]
)
+py_test(
+ name = "test_e2e_wandb_integration",
+ size = "small",
+ srcs = ["tests/test_e2e_wandb_integration.py"],
+ tags = ["team:ml", "exclusive"],
+ deps = [":train_lib"]
+)
+
py_test(
name = "test_worker_group",
size = "medium",
diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py
index 181f8bab95eb9..dcd6495c9f3ba 100644
--- a/python/ray/train/base_trainer.py
+++ b/python/ray/train/base_trainer.py
@@ -1,4 +1,5 @@
import abc
+import copy
import inspect
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
@@ -141,6 +142,10 @@ def training_loop(self):
_handles_checkpoint_freq: bool = False
_handles_checkpoint_at_end: bool = False
+ # fields to propagate to Tuner param_space.
+ # See `BaseTrainer._extract_fields_for_tuner_param_space` for more details.
+ _fields_for_tuner_param_space = []
+
def __init__(
self,
*,
@@ -350,8 +355,11 @@ def fit(self) -> Result:
from ray.tune.error import TuneError
trainable = self.as_trainable()
+ param_space = self._extract_fields_for_tuner_param_space()
- tuner = Tuner(trainable=trainable, run_config=self.run_config)
+ tuner = Tuner(
+ trainable=trainable, param_space=param_space, run_config=self.run_config
+ )
result_grid = tuner.fit()
assert len(result_grid) == 1
try:
@@ -362,6 +370,23 @@ def fit(self) -> Result:
raise TrainingFailedError from e
return result
+ def _extract_fields_for_tuner_param_space(self) -> Dict:
+ """Extracts fields to be included in `Tuner.param_space`.
+
+ This is needed to leverage the full logging/integration offerings from Tune.
+ For example, `param_space` is logged automatically to wandb integration.
+
+ Currently only done for `train_loop_config`.
+
+ Returns:
+ A dictionary that should be passed to Tuner.param_space.
+ """
+ result = {}
+ for key in self._fields_for_tuner_param_space:
+ if key in self._param_dict.keys():
+ result[key] = copy.deepcopy(self._param_dict[key])
+ return result
+
def _generate_trainable_cls(self) -> Type["Trainable"]:
"""Generate the base Trainable class.
diff --git a/python/ray/train/constants.py b/python/ray/train/constants.py
index 5eb91f6abb49c..abb9e3cffb26c 100644
--- a/python/ray/train/constants.py
+++ b/python/ray/train/constants.py
@@ -68,11 +68,8 @@
# as Trainable)
DISABLE_LAZY_CHECKPOINTING_ENV = "TRAIN_DISABLE_LAZY_CHECKPOINTING"
-# Default NCCL_SOCKET_IFNAME.
-# Use ethernet when possible.
-# NCCL_SOCKET_IFNAME does a prefix match so "ens3" or "ens5" will match with
-# "en".
-DEFAULT_NCCL_SOCKET_IFNAME = "en,eth,bond"
+# Blacklist virtualized networking.
+DEFAULT_NCCL_SOCKET_IFNAME = "^lo,docker,veth"
# Key for AIR Checkpoint metadata in TrainingResult metadata
CHECKPOINT_METADATA_KEY = "checkpoint_metadata"
diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py
index 8b5983dbae322..d1e826ffea4f9 100644
--- a/python/ray/train/data_parallel_trainer.py
+++ b/python/ray/train/data_parallel_trainer.py
@@ -114,7 +114,7 @@ def train_loop_per_worker():
Any returns from the ``train_loop_per_worker`` will be discarded and not
used or persisted anywhere.
- **How do I use ``DataParallelTrainer`` or any of its subclasses?**
+ **How do I use DataParallelTrainer or any of its subclasses?**
Example:
@@ -136,7 +136,7 @@ def train_loop_for_worker():
)
result = trainer.fit()
- **How do I develop on top of ``DataParallelTrainer``?**
+ **How do I develop on top of DataParallelTrainer?**
In many cases, using DataParallelTrainer directly is sufficient to execute
functions on multiple actors.
@@ -241,6 +241,10 @@ def __init__(self, train_loop_per_worker, my_backend_config:
WILDCARD_KEY: DatasetConfig(split=False),
}
+ _fields_for_tuner_param_space = BaseTrainer._fields_for_tuner_param_space + [
+ "train_loop_config"
+ ]
+
def __init__(
self,
train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
diff --git a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py
index 46ea6ab3947ae..90846eb84824e 100644
--- a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py
+++ b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py
@@ -69,7 +69,6 @@ def train_func(config):
epochs = config.get("epochs", 3)
model = resnet18()
- model = train.torch.prepare_model(model)
# Create optimizer.
optimizer_config = {
@@ -98,6 +97,8 @@ def train_func(config):
checkpoint_epoch = checkpoint_dict["epoch"]
starting_epoch = checkpoint_epoch + 1
+ model = train.torch.prepare_model(model)
+
# Load in training and validation data.
transform_train = transforms.Compose(
[
diff --git a/python/ray/train/tests/test_e2e_wandb_integration.py b/python/ray/train/tests/test_e2e_wandb_integration.py
new file mode 100644
index 0000000000000..9870d40b2bc50
--- /dev/null
+++ b/python/ray/train/tests/test_e2e_wandb_integration.py
@@ -0,0 +1,80 @@
+"""
+If a user uses Trainer API directly with wandb integration, they expect to see
+* train_loop_config to show up in wandb.config.
+
+This test uses mocked call into wandb API.
+"""
+
+import pytest
+
+import ray
+from ray.air import RunConfig, ScalingConfig
+from ray.air.integrations.wandb import WANDB_ENV_VAR
+from ray.air.tests.mocked_wandb_integration import WandbTestExperimentLogger
+from ray.train.examples.pytorch.torch_linear_example import (
+ train_func as linear_train_func,
+)
+from ray.train.torch import TorchTrainer
+
+
+@pytest.fixture
+def ray_start_4_cpus():
+ address_info = ray.init(num_cpus=4)
+ yield address_info
+ # The code after the yield will run as teardown code.
+ ray.shutdown()
+
+
+CONFIG = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": 3}
+
+
+@pytest.mark.parametrize("with_train_loop_config", (True, False))
+def test_trainer_wandb_integration(
+ ray_start_4_cpus, with_train_loop_config, monkeypatch
+):
+ monkeypatch.setenv(WANDB_ENV_VAR, "9012")
+
+ def train_func(config=None):
+ config = config or CONFIG
+ result = linear_train_func(config)
+ assert len(result) == config["epochs"]
+ assert result[-1]["loss"] < result[0]["loss"]
+
+ scaling_config = ScalingConfig(num_workers=2)
+
+ logger = WandbTestExperimentLogger(project="test_project")
+ if with_train_loop_config:
+ trainer = TorchTrainer(
+ train_loop_per_worker=train_func,
+ train_loop_config=CONFIG,
+ scaling_config=scaling_config,
+ run_config=RunConfig(callbacks=[logger]),
+ )
+ else:
+ trainer = TorchTrainer(
+ train_loop_per_worker=train_func,
+ scaling_config=scaling_config,
+ run_config=RunConfig(callbacks=[logger]),
+ )
+ trainer.fit()
+ # We use local actor for mocked logger.
+ # As a result, `._wandb`, `.config` and `.queue` are
+ # guaranteed to be available by the time `trainer.fit()` returns.
+ # This is so because they are generated in corresponding initializer
+ # in a sync fashion.
+ config = list(logger.trial_processes.values())[0]._wandb.config.queue.get(
+ timeout=10
+ )
+
+ if with_train_loop_config:
+ assert "train_loop_config" in config
+ else:
+ assert "train_loop_config" not in config
+
+
+if __name__ == "__main__":
+ import sys
+
+ import pytest
+
+ sys.exit(pytest.main(["-v", "-x", __file__]))
diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py
index 8323d0ad5f8c7..a21a6a93adb09 100644
--- a/python/ray/train/torch/torch_trainer.py
+++ b/python/ray/train/torch/torch_trainer.py
@@ -93,6 +93,40 @@ def train_loop_per_worker():
To save a model to use for the ``TorchPredictor``, you must save it under the
"model" kwarg in ``Checkpoint`` passed to ``session.report()``.
+ .. note::
+ When you wrap the ``model`` with ``prepare_model``, the keys of its
+ ``state_dict`` are prefixed by ``module.``. For example,
+ ``layer1.0.bn1.bias`` becomes ``module.layer1.0.bn1.bias``.
+ However, when saving ``model`` through ``session.report()``
+ all ``module.`` prefixes are stripped.
+ As a result, when you load from a saved checkpoint, make sure that
+ you first load ``state_dict`` to the model
+ before calling ``prepare_model``.
+ Otherwise, you will run into errors like
+ ``Error(s) in loading state_dict for DistributedDataParallel:
+ Missing key(s) in state_dict: "module.conv1.weight", ...``. See snippet below.
+
+ .. testcode::
+
+ from torchvision.models import resnet18
+ from ray.air import session
+ from ray.air.checkpoint import Checkpoint
+ import ray.train as train
+
+ def train_func():
+ ...
+ model = resnet18()
+ model = train.torch.prepare_model(model)
+ for epoch in range(3):
+ ...
+ ckpt = Checkpoint.from_dict({
+ "epoch": epoch,
+ "model": model.state_dict(),
+ # "model": model.module.state_dict(),
+ # ** The above two are equivalent **
+ })
+ session.report({"foo": "bar"}, ckpt)
+
Example:
.. testcode::
diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py
index 2026107a9c43e..da57fe46c949d 100644
--- a/python/ray/tune/analysis/experiment_analysis.py
+++ b/python/ray/tune/analysis/experiment_analysis.py
@@ -30,10 +30,7 @@
TRAINING_ITERATION,
)
from ray.tune.experiment import Trial
-from ray.tune.execution.trial_runner import (
- _find_newest_experiment_checkpoint,
- _load_trial_from_checkpoint,
-)
+from ray.tune.execution.trial_runner import _find_newest_experiment_checkpoint
from ray.tune.trainable.util import TrainableUtil
from ray.tune.utils.util import unflattened_lookup
@@ -143,13 +140,10 @@ def _load_checkpoints_from_latest(self, latest_checkpoint: List[str]) -> None:
if "checkpoints" not in experiment_state:
raise TuneError("Experiment state invalid; no checkpoints found.")
+
self._checkpoints_and_paths += [
- (_decode_checkpoint_from_experiment_state(cp), Path(path).parent)
- for cp in experiment_state["checkpoints"]
+ (cp, Path(path).parent) for cp in experiment_state["checkpoints"]
]
- self._checkpoints_and_paths = sorted(
- self._checkpoints_and_paths, key=lambda tup: tup[0]["trial_id"]
- )
def _get_latest_checkpoint(self, experiment_checkpoint_path: Path) -> List[str]:
# Case 1: Dir specified, find latest checkpoint.
@@ -798,12 +792,10 @@ def _get_trial_paths(self) -> List[str]:
"out of sync, as checkpointing is periodic."
)
self.trials = []
- _trial_paths = []
- for checkpoint, path in self._checkpoints_and_paths:
+ for trial_json_state, path in self._checkpoints_and_paths:
try:
- trial = _load_trial_from_checkpoint(
- checkpoint, stub=True, new_local_dir=str(path)
- )
+ trial = Trial.from_json_state(trial_json_state, stub=True)
+ trial.local_dir = str(path)
except Exception:
logger.warning(
f"Could not load trials from experiment checkpoint. "
@@ -814,7 +806,9 @@ def _get_trial_paths(self) -> List[str]:
)
continue
self.trials.append(trial)
- _trial_paths.append(str(trial.logdir))
+
+ self.trials.sort(key=lambda trial: trial.trial_id)
+ _trial_paths = [str(trial.logdir) for trial in self.trials]
if not _trial_paths:
raise TuneError("No trials found.")
@@ -882,7 +876,3 @@ def make_stub_if_needed(trial: Trial) -> Trial:
state["trials"] = [make_stub_if_needed(t) for t in state["trials"]]
return state
-
-
-def _decode_checkpoint_from_experiment_state(cp: Union[str, dict]) -> dict:
- return json.loads(cp, cls=TuneFunctionDecoder) if isinstance(cp, str) else cp
diff --git a/python/ray/tune/execution/trial_runner.py b/python/ray/tune/execution/trial_runner.py
index aa7817dee87e3..0479dd6eeb4df 100644
--- a/python/ray/tune/execution/trial_runner.py
+++ b/python/ray/tune/execution/trial_runner.py
@@ -1,6 +1,6 @@
from collections import defaultdict
from dataclasses import dataclass
-from typing import Any, DefaultDict, List, Mapping, Optional, Union, Tuple, Set
+from typing import DefaultDict, List, Optional, Union, Tuple, Set
import click
from datetime import datetime
@@ -67,65 +67,6 @@ def _find_newest_experiment_checkpoint(ckpt_dir) -> Optional[str]:
return max(full_paths)
-def _load_trial_from_checkpoint(
- trial_cp: dict, stub: bool = False, new_local_dir: Optional[str] = None
-) -> Trial:
- """Create a Trial from the state stored in the experiment checkpoint.
-
- Args:
- trial_cp: Trial state from the experiment checkpoint, which is loaded
- from the trial's `Trial.get_json_state`.
- stub: Whether or not to validate the trainable name when creating the Trial.
- Used for testing purposes for creating mocks.
- new_local_dir: If set, this `local_dir` will overwrite what's saved in the
- `trial_cp` state. Used in the case that the trial directory has moved.
- The Trial `logdir` and the persistent trial checkpoints will have their
- paths updated relative to this new directory.
-
- Returns:
- new_trial: New trial with state loaded from experiment checkpoint
- """
- new_trial = Trial(
- trial_cp["trainable_name"],
- stub=stub,
- _setup_default_resource=False,
- )
- if new_local_dir:
- trial_cp["local_dir"] = new_local_dir
- new_trial.__setstate__(trial_cp)
- new_trial.refresh_default_resource_request()
- return new_trial
-
-
-def _load_trials_from_experiment_checkpoint(
- experiment_checkpoint: Mapping[str, Any],
- stub: bool = False,
- new_local_dir: Optional[str] = None,
-) -> List[Trial]:
- """Create trial objects from experiment checkpoint.
-
- Given an experiment checkpoint (TrialRunner state dict), return
- list of trials. See `_ExperimentCheckpointManager.checkpoint` for
- what's saved in the TrialRunner state dict.
- """
- checkpoints = [
- json.loads(cp, cls=TuneFunctionDecoder) if isinstance(cp, str) else cp
- for cp in experiment_checkpoint["checkpoints"]
- ]
-
- trials = []
- for trial_cp in checkpoints:
- trials.append(
- _load_trial_from_checkpoint(
- trial_cp,
- stub=stub,
- new_local_dir=new_local_dir,
- )
- )
-
- return trials
-
-
@dataclass
class _ResumeConfig:
resume_unfinished: bool = True
@@ -154,17 +95,16 @@ class _ExperimentCheckpointManager:
def __init__(
self,
- checkpoint_dir: str,
+ local_checkpoint_dir: str,
checkpoint_period: Union[int, float, str],
start_time: float,
session_str: str,
syncer: Syncer,
sync_trial_checkpoints: bool,
- local_dir: str,
- remote_dir: str,
+ remote_checkpoint_dir: str,
sync_every_n_trial_checkpoints: Optional[int] = None,
):
- self._checkpoint_dir = checkpoint_dir
+ self._local_checkpoint_dir = local_checkpoint_dir
self._auto_checkpoint_enabled = checkpoint_period == "auto"
if self._auto_checkpoint_enabled:
self._checkpoint_period = 10.0 # Initial value
@@ -176,8 +116,7 @@ def __init__(
self._syncer = syncer
self._sync_trial_checkpoints = sync_trial_checkpoints
- self._local_dir = local_dir
- self._remote_dir = remote_dir
+ self._remote_checkpoint_dir = remote_checkpoint_dir
self._last_checkpoint_time = 0.0
self._last_sync_time = 0.0
@@ -225,7 +164,7 @@ def checkpoint(
Args:
force: Forces a checkpoint despite checkpoint_period.
"""
- if not self._checkpoint_dir:
+ if not self._local_checkpoint_dir:
return
force = force or self._should_force_cloud_sync
@@ -243,12 +182,14 @@ def _serialize_and_write():
"timestamp": self._last_checkpoint_time,
},
}
- tmp_file_name = os.path.join(self._checkpoint_dir, ".tmp_checkpoint")
+ tmp_file_name = os.path.join(self._local_checkpoint_dir, ".tmp_checkpoint")
with open(tmp_file_name, "w") as f:
json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)
os.replace(tmp_file_name, checkpoint_file)
- search_alg.save_to_dir(self._checkpoint_dir, session_str=self._session_str)
+ search_alg.save_to_dir(
+ self._local_checkpoint_dir, session_str=self._session_str
+ )
checkpoint_time_start = time.monotonic()
with out_of_band_serialize_dataset():
@@ -274,14 +215,14 @@ def _serialize_and_write():
"`sync_timeout` in `SyncConfig`."
)
synced = self._syncer.sync_up(
- local_dir=self._local_dir,
- remote_dir=self._remote_dir,
+ local_dir=self._local_checkpoint_dir,
+ remote_dir=self._remote_checkpoint_dir,
exclude=exclude,
)
else:
synced = self._syncer.sync_up_if_needed(
- local_dir=self._local_dir,
- remote_dir=self._remote_dir,
+ local_dir=self._local_checkpoint_dir,
+ remote_dir=self._remote_checkpoint_dir,
exclude=exclude,
)
@@ -320,7 +261,7 @@ def _serialize_and_write():
)
self._last_checkpoint_time = time.time()
- return self._checkpoint_dir
+ return self._local_checkpoint_dir
@DeveloperAPI
@@ -350,14 +291,15 @@ class TrialRunner:
search_alg: SearchAlgorithm for generating
Trial objects.
scheduler: Defaults to FIFOScheduler.
- local_checkpoint_dir: Path where
- global checkpoints are stored and restored from.
- remote_checkpoint_dir: Remote path where
- global checkpoints are stored and restored from. Used
- if `resume` == REMOTE.
- sync_config: See `tune.py:run`.
- stopper: Custom class for stopping whole experiments. See
- ``Stopper``.
+ local_checkpoint_dir: Path where global experiment state checkpoints
+ are saved and restored from.
+ sync_config: See :class:`~ray.tune.syncer.SyncConfig`.
+ Within sync config, the `upload_dir` specifies cloud storage, and
+ experiment state checkpoints will be synced to the `remote_checkpoint_dir`:
+ `{sync_config.upload_dir}/{experiment_name}`.
+ experiment_dir_name: Experiment directory name.
+ See :class:`~ray.tune.experiment.Experiment`.
+ stopper: Custom class for stopping whole experiments. See ``Stopper``.
resume: see `tune.py:run`.
server_port: Port number for launching TuneServer.
fail_fast: Finishes as soon as a trial fails if True.
@@ -388,8 +330,8 @@ def __init__(
search_alg: Optional[SearchAlgorithm] = None,
scheduler: Optional[TrialScheduler] = None,
local_checkpoint_dir: Optional[str] = None,
- remote_checkpoint_dir: Optional[str] = None,
sync_config: Optional[SyncConfig] = None,
+ experiment_dir_name: Optional[str] = None,
stopper: Optional[Stopper] = None,
resume: Union[str, bool] = False,
server_port: Optional[int] = None,
@@ -436,11 +378,11 @@ def __init__(
# Manual override
self._max_pending_trials = int(max_pending_trials)
- sync_config = sync_config or SyncConfig()
+ self._sync_config = sync_config or SyncConfig()
self.trial_executor.setup(
max_pending_trials=self._max_pending_trials,
- trainable_kwargs={"sync_timeout": sync_config.sync_timeout},
+ trainable_kwargs={"sync_timeout": self._sync_config.sync_timeout},
)
self._metric = metric
@@ -485,9 +427,9 @@ def __init__(
if self._local_checkpoint_dir:
os.makedirs(self._local_checkpoint_dir, exist_ok=True)
- self._remote_checkpoint_dir = remote_checkpoint_dir
+ self._experiment_dir_name = experiment_dir_name
- self._syncer = get_node_to_storage_syncer(sync_config)
+ self._syncer = get_node_to_storage_syncer(self._sync_config)
self._stopper = stopper or NoopStopper()
self._resumed = False
@@ -562,14 +504,13 @@ def end_experiment_callbacks(self) -> None:
def _create_checkpoint_manager(self, sync_trial_checkpoints: bool = True):
return _ExperimentCheckpointManager(
- checkpoint_dir=self._local_checkpoint_dir,
+ local_checkpoint_dir=self._local_checkpoint_dir,
checkpoint_period=self._checkpoint_period,
start_time=self._start_time,
session_str=self._session_str,
syncer=self._syncer,
sync_trial_checkpoints=sync_trial_checkpoints,
- local_dir=self._local_checkpoint_dir,
- remote_dir=self._remote_checkpoint_dir,
+ remote_checkpoint_dir=self._remote_checkpoint_dir,
sync_every_n_trial_checkpoints=self._trial_checkpoint_config.num_to_keep,
)
@@ -585,6 +526,12 @@ def search_alg(self):
def scheduler_alg(self):
return self._scheduler_alg
+ @property
+ def _remote_checkpoint_dir(self):
+ if self._sync_config.upload_dir and self._experiment_dir_name:
+ return os.path.join(self._sync_config.upload_dir, self._experiment_dir_name)
+ return None
+
def _validate_resume(
self, resume_type: Union[str, bool], driver_sync_trial_checkpoints=True
) -> Tuple[bool, Optional[_ResumeConfig]]:
@@ -845,19 +792,34 @@ def resume(
)
)
- trial_runner_data = runner_state["runner_data"]
- # Don't overwrite the current `_local_checkpoint_dir`
- # The current directory could be different from the checkpointed
- # directory, if the experiment directory has changed.
- trial_runner_data.pop("_local_checkpoint_dir", None)
+ # 1. Restore trial runner state
+ self.__setstate__(runner_state["runner_data"])
- self.__setstate__(trial_runner_data)
+ # 2. Restore search algorithm state
if self._search_alg.has_checkpoint(self._local_checkpoint_dir):
self._search_alg.restore_from_dir(self._local_checkpoint_dir)
- trials = _load_trials_from_experiment_checkpoint(
- runner_state, new_local_dir=self._local_checkpoint_dir
- )
+ # 3. Load trial table from experiment checkpoint
+ trials = []
+ for trial_json_state in runner_state["checkpoints"]:
+ trial = Trial.from_json_state(trial_json_state)
+
+ # The following properties may be updated on restoration
+ # Ex: moved local/cloud experiment directory
+ trial.local_dir = self._local_checkpoint_dir
+ trial.sync_config = self._sync_config
+ trial.experiment_dir_name = self._experiment_dir_name
+
+ # Avoid creating logdir in client mode for returned trial results,
+ # since the dir might not be creatable locally.
+ # TODO(ekl) this is kind of a hack.
+ if not ray.util.client.ray.is_connected():
+ trial.init_logdir() # Create logdir if it does not exist
+
+ trial.refresh_default_resource_request()
+ trials.append(trial)
+
+ # 4. Set trial statuses according to the resume configuration
for trial in sorted(trials, key=lambda t: t.last_update_time, reverse=True):
trial_to_add = trial
if trial.status == Trial.ERROR:
@@ -1623,6 +1585,9 @@ def __getstate__(self):
"_syncer",
"_callbacks",
"_checkpoint_manager",
+ "_local_checkpoint_dir",
+ "_sync_config",
+ "_experiment_dir_name",
]:
del state[k]
state["launch_web_server"] = bool(self._server)
diff --git a/python/ray/tune/experiment/trial.py b/python/ray/tune/experiment/trial.py
index fa71ae26f4b41..88f37b8a0a48f 100644
--- a/python/ray/tune/experiment/trial.py
+++ b/python/ray/tune/experiment/trial.py
@@ -40,7 +40,7 @@
PlacementGroupFactory,
resource_dict_to_pg_factory,
)
-from ray.tune.utils.serialization import TuneFunctionEncoder
+from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
from ray.tune.trainable.util import TrainableUtil
from ray.tune.utils import date_str, flatten_dict
from ray.util.annotations import DeveloperAPI
@@ -293,7 +293,7 @@ def __init__(
self.trainable_name = trainable_name
self.trial_id = Trial.generate_id() if trial_id is None else trial_id
self.config = config or {}
- self.local_dir = local_dir # This remains unexpanded for syncing.
+ self._local_dir = local_dir # This remains unexpanded for syncing.
# Parameters that Tune varies across searches.
self.evaluated_params = evaluated_params or {}
@@ -472,9 +472,39 @@ def get_runner_ip(self) -> Optional[str]:
self.location = _Location(hostname, pid)
return self.location.hostname
+ @property
+ def local_dir(self):
+ return self._local_dir
+
+ @local_dir.setter
+ def local_dir(self, local_dir):
+ relative_checkpoint_dirs = []
+ if self.logdir:
+ # Save the relative paths of persistent trial checkpoints, which are saved
+ # relative to the old `local_dir`/`logdir`
+ for checkpoint in self.get_trial_checkpoints():
+ checkpoint_dir = checkpoint.dir_or_data
+ assert isinstance(checkpoint_dir, str)
+ relative_checkpoint_dirs.append(
+ os.path.relpath(checkpoint_dir, self.logdir)
+ )
+
+ # Update the underlying `_local_dir`, which also updates the trial `logdir`
+ self._local_dir = local_dir
+
+ if self.logdir:
+ for checkpoint, relative_checkpoint_dir in zip(
+ self.get_trial_checkpoints(), relative_checkpoint_dirs
+ ):
+ # Reconstruct the checkpoint dir using the (possibly updated)
+ # trial logdir and the relative checkpoint directory.
+ checkpoint.dir_or_data = os.path.join(
+ self.logdir, relative_checkpoint_dir
+ )
+
@property
def logdir(self):
- if not self.relative_logdir:
+ if not self.local_dir or not self.relative_logdir:
return None
return str(Path(self.local_dir).joinpath(self.relative_logdir))
@@ -901,6 +931,20 @@ def get_json_state(self) -> str:
self._state_valid = True
return self._state_json
+ @classmethod
+ def from_json_state(cls, json_state: str, stub: bool = False) -> "Trial":
+ trial_state = json.loads(json_state, cls=TuneFunctionDecoder)
+
+ new_trial = Trial(
+ trial_state["trainable_name"],
+ stub=stub,
+ _setup_default_resource=False,
+ )
+
+ new_trial.__setstate__(trial_state)
+
+ return new_trial
+
def __getstate__(self):
"""Memento generator for Trial.
@@ -922,53 +966,21 @@ def __getstate__(self):
state["_state_valid"] = False
state["_default_result_or_future"] = None
- # Save the relative paths of persistent trial checkpoints
- # When loading this trial state, the paths should be constructed again
- # relative to the trial `logdir`, which may have been updated.
- relative_checkpoint_dirs = []
- for checkpoint in self.get_trial_checkpoints():
- checkpoint_dir = checkpoint.dir_or_data
- assert isinstance(checkpoint_dir, str)
- relative_checkpoint_dirs.append(
- os.path.relpath(checkpoint_dir, self.logdir)
- )
- state["__relative_checkpoint_dirs"] = relative_checkpoint_dirs
-
return copy.deepcopy(state)
def __setstate__(self, state):
-
if state["status"] == Trial.RUNNING:
state["status"] = Trial.PENDING
for key in self._nonjson_fields:
if key in state:
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
- # Retrieve the relative checkpoint dirs
- relative_checkpoint_dirs = state.pop("__relative_checkpoint_dirs", None)
-
# Ensure that stub doesn't get overriden
stub = state.pop("stub", True)
self.__dict__.update(state)
self.stub = stub or getattr(self, "stub", False)
- if relative_checkpoint_dirs:
- for checkpoint, relative_checkpoint_dir in zip(
- self.get_trial_checkpoints(), relative_checkpoint_dirs
- ):
- # Reconstruct the checkpoint dir using the (possibly updated)
- # trial logdir and the relative checkpoint directory.
- checkpoint.dir_or_data = os.path.join(
- self.logdir, relative_checkpoint_dir
- )
-
if not self.stub:
validate_trainable(self.trainable_name)
assert self.placement_group_factory
-
- # Avoid creating logdir in client mode for returned trial results,
- # since the dir might not be creatable locally.
- # TODO(ekl) this is kind of a hack.
- if not ray.util.client.ray.is_connected():
- self.init_logdir() # Create logdir if it does not exist
diff --git a/python/ray/tune/impl/tuner_internal.py b/python/ray/tune/impl/tuner_internal.py
index b47f3321e4746..38dd889d8912b 100644
--- a/python/ray/tune/impl/tuner_internal.py
+++ b/python/ray/tune/impl/tuner_internal.py
@@ -7,6 +7,7 @@
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING, Tuple
+import urllib.parse
import ray
import ray.cloudpickle as pickle
@@ -334,6 +335,16 @@ def _restore_from_path_or_uri(
self._run_config.local_dir = str(experiment_path.parent)
self._run_config.name = experiment_path.name
else:
+ # Set the experiment `name` and `upload_dir` according to the URI
+ parsed_uri = urllib.parse.urlparse(path_or_uri)
+ remote_path = Path(os.path.normpath(parsed_uri.netloc + parsed_uri.path))
+ upload_dir = parsed_uri._replace(
+ netloc="", path=str(remote_path.parent)
+ ).geturl()
+
+ self._run_config.name = remote_path.name
+ self._run_config.sync_config.upload_dir = upload_dir
+
# If we synced, `experiment_checkpoint_dir` will contain a temporary
# directory. Create an experiment checkpoint dir instead and move
# our data there.
diff --git a/python/ray/tune/search/nevergrad/nevergrad_search.py b/python/ray/tune/search/nevergrad/nevergrad_search.py
index 0a8b6ba9b8722..5b57069373e54 100644
--- a/python/ray/tune/search/nevergrad/nevergrad_search.py
+++ b/python/ray/tune/search/nevergrad/nevergrad_search.py
@@ -50,8 +50,12 @@ class NevergradSearch(Searcher):
$ pip install nevergrad
Parameters:
- optimizer: Optimizer provided
- from Nevergrad. Alter
+ optimizer: Optimizer class provided from Nevergrad.
+ See here for available optimizers:
+ https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizers
+ This can also be an instance of a `ConfiguredOptimizer`. See the
+ section on configured optimizers in the above link.
+ optimizer_kwargs: Kwargs passed in when instantiating the `optimizer`
space: Nevergrad parametrization
to be passed to optimizer on instantiation, or list of parameter
names if you passed an optimizer object.
@@ -120,11 +124,11 @@ def __init__(
optimizer: Optional[
Union[Optimizer, Type[Optimizer], ConfiguredOptimizer]
] = None,
+ optimizer_kwargs: Optional[Dict] = None,
space: Optional[Union[Dict, Parameter]] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
points_to_evaluate: Optional[List[Dict]] = None,
- **kwargs,
):
assert (
ng is not None
@@ -134,11 +138,12 @@ def __init__(
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
- super(NevergradSearch, self).__init__(metric=metric, mode=mode, **kwargs)
+ super(NevergradSearch, self).__init__(metric=metric, mode=mode)
self._space = None
self._opt_factory = None
self._nevergrad_opt = None
+ self._optimizer_kwargs = optimizer_kwargs or {}
if points_to_evaluate is None:
self._points_to_evaluate = None
@@ -166,6 +171,13 @@ def __init__(
"pass a list of parameter names or None as the `space` "
"parameter."
)
+ if self._optimizer_kwargs:
+ raise ValueError(
+ "If you pass in optimizer kwargs, either pass "
+ "an `Optimizer` subclass or an instance of "
+ "`ConfiguredOptimizer`."
+ )
+
self._parameters = space
self._nevergrad_opt = optimizer
elif (
@@ -187,7 +199,9 @@ def __init__(
def _setup_nevergrad(self):
if self._opt_factory:
- self._nevergrad_opt = self._opt_factory(self._space)
+ self._nevergrad_opt = self._opt_factory(
+ self._space, **self._optimizer_kwargs
+ )
# nevergrad.tell internally minimizes, so "max" => -1
if self._mode == "max":
diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py
index 13d2d1a4fca70..5467fb763602e 100644
--- a/python/ray/tune/tests/test_api.py
+++ b/python/ray/tune/tests/test_api.py
@@ -1806,55 +1806,22 @@ def train(config, reporter):
self.assertEqual(trial.last_result["mean_accuracy"], float("inf"))
def testSearcherSchedulerStr(self):
- def train(config):
- tune.report(metric=1)
-
capture = {}
class MockTrialRunner(TrialRunner):
- def __init__(
- self,
- search_alg=None,
- scheduler=None,
- local_checkpoint_dir=None,
- remote_checkpoint_dir=None,
- sync_config=None,
- stopper=None,
- resume=False,
- server_port=None,
- fail_fast=False,
- checkpoint_period=None,
- trial_executor=None,
- callbacks=None,
- metric=None,
- trial_checkpoint_config=None,
- driver_sync_trial_checkpoints=True,
- ):
- # should be converted from strings at this case
- # and not None
+ def __init__(self, search_alg=None, scheduler=None, **kwargs):
+ # should be converted from strings at this case and not None
capture["search_alg"] = search_alg
capture["scheduler"] = scheduler
super().__init__(
search_alg=search_alg,
scheduler=scheduler,
- local_checkpoint_dir=local_checkpoint_dir,
- remote_checkpoint_dir=remote_checkpoint_dir,
- sync_config=sync_config,
- stopper=stopper,
- resume=resume,
- server_port=server_port,
- fail_fast=fail_fast,
- checkpoint_period=checkpoint_period,
- trial_executor=trial_executor,
- callbacks=callbacks,
- metric=metric,
- trial_checkpoint_config=trial_checkpoint_config,
- driver_sync_trial_checkpoints=True,
+ **kwargs,
)
with patch("ray.tune.tune.TrialRunner", MockTrialRunner):
tune.run(
- train,
+ lambda config: tune.report(metric=1),
search_alg="random",
scheduler="async_hyperband",
metric="metric",
@@ -1889,42 +1856,14 @@ def train(config):
capture = {}
class MockTrialRunner(TrialRunner):
- def __init__(
- self,
- search_alg=None,
- scheduler=None,
- local_checkpoint_dir=None,
- remote_checkpoint_dir=None,
- sync_config=None,
- stopper=None,
- resume=False,
- server_port=None,
- fail_fast=False,
- checkpoint_period=None,
- trial_executor=None,
- callbacks=None,
- metric=None,
- trial_checkpoint_config=None,
- driver_sync_trial_checkpoints=True,
- ):
+ def __init__(self, search_alg=None, scheduler=None, **kwargs):
+ # should be converted from strings at this case and not None
capture["search_alg"] = search_alg
capture["scheduler"] = scheduler
super().__init__(
search_alg=search_alg,
scheduler=scheduler,
- local_checkpoint_dir=local_checkpoint_dir,
- remote_checkpoint_dir=remote_checkpoint_dir,
- sync_config=sync_config,
- stopper=stopper,
- resume=resume,
- server_port=server_port,
- fail_fast=fail_fast,
- checkpoint_period=checkpoint_period,
- trial_executor=trial_executor,
- callbacks=callbacks,
- metric=metric,
- trial_checkpoint_config=trial_checkpoint_config,
- driver_sync_trial_checkpoints=driver_sync_trial_checkpoints,
+ **kwargs,
)
with patch("ray.tune.tune.TrialRunner", MockTrialRunner):
diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py
index dc5e32ec9f339..c89202b50d2b4 100644
--- a/python/ray/tune/tests/test_experiment_analysis.py
+++ b/python/ray/tune/tests/test_experiment_analysis.py
@@ -91,14 +91,8 @@ def testStats(self):
def testTrialDataframe(self):
checkpoints = self.ea._checkpoints_and_paths
idx = random.randint(0, len(checkpoints) - 1)
- logdir_from_checkpoint = str(
- checkpoints[idx][1].joinpath(checkpoints[idx][0]["relative_logdir"])
- )
logdir_from_trial = self.ea.trials[idx].logdir
-
- self.assertEqual(logdir_from_checkpoint, logdir_from_trial)
-
- trial_df = self.ea.trial_dataframes[logdir_from_checkpoint]
+ trial_df = self.ea.trial_dataframes[logdir_from_trial]
self.assertTrue(isinstance(trial_df, pd.DataFrame))
self.assertEqual(trial_df.shape[0], 1)
diff --git a/python/ray/tune/tests/test_experiment_analysis_mem.py b/python/ray/tune/tests/test_experiment_analysis_mem.py
index 315ceaca1fec3..1033f7ff38aff 100644
--- a/python/ray/tune/tests/test_experiment_analysis_mem.py
+++ b/python/ray/tune/tests/test_experiment_analysis_mem.py
@@ -1,4 +1,3 @@
-import json
import unittest
import shutil
import tempfile
@@ -56,30 +55,6 @@ def load_checkpoint(self, checkpoint_path):
def tearDown(self):
shutil.rmtree(self.test_dir, ignore_errors=True)
- def testInitLegacy(self):
- """Should still work if checkpoints are not json strings"""
- experiment_checkpoint_path = os.path.join(
- self.test_dir, "experiment_state.json"
- )
- checkpoint_data = {
- "checkpoints": [
- {
- "trial_id": "abcd1234",
- "status": Trial.TERMINATED,
- "trainable_name": "MockTrainable",
- "local_dir": self.test_dir,
- "relative_logdir": "MockTrainable_0_id=3_2020-07-12",
- }
- ]
- }
-
- with open(experiment_checkpoint_path, "w") as f:
- f.write(json.dumps(checkpoint_data))
-
- experiment_analysis = ExperimentAnalysis(experiment_checkpoint_path)
- self.assertEqual(len(experiment_analysis._checkpoints_and_paths), 1)
- self.assertTrue(experiment_analysis.trials)
-
def testInit(self):
trial = Trial(
"MockTrainable", stub=True, trial_id="abcd1234", local_dir=self.test_dir
diff --git a/python/ray/tune/tests/test_searchers.py b/python/ray/tune/tests/test_searchers.py
index f74a1266f4127..1bf0521ef3f25 100644
--- a/python/ray/tune/tests/test_searchers.py
+++ b/python/ray/tune/tests/test_searchers.py
@@ -272,6 +272,12 @@ def testNevergrad(self):
)
self.assertCorrectExperimentOutput(out)
+ def testNevergradWithRequiredOptimizerKwargs(self):
+ from ray.tune.search.nevergrad import NevergradSearch
+ import nevergrad as ng
+
+ NevergradSearch(optimizer=ng.optimizers.CM, optimizer_kwargs=dict(budget=16))
+
def testOptuna(self):
from ray.tune.search.optuna import OptunaSearch
from optuna.samplers import RandomSampler
diff --git a/python/ray/tune/tests/test_trial_relative_logdir.py b/python/ray/tune/tests/test_trial_relative_logdir.py
index 41f7c6c466101..f2f388ec6858c 100644
--- a/python/ray/tune/tests/test_trial_relative_logdir.py
+++ b/python/ray/tune/tests/test_trial_relative_logdir.py
@@ -1,4 +1,3 @@
-import json
import os
import shutil
import sys
@@ -13,9 +12,7 @@
import ray
from ray import tune
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
-from ray.tune.execution.trial_runner import _load_trial_from_checkpoint
from ray.tune.experiment import Trial
-from ray.tune.utils.serialization import TuneFunctionDecoder
def train(config):
@@ -259,8 +256,9 @@ def testRelativeLogdirWithJson(self):
def test_load_trial_from_json_state(tmpdir):
- """Check that `Trial.get_json_state` and `_load_trial_from_checkpoint`
- for saving and loading a Trial is done correctly."""
+ """Check that serializing a trial to a JSON string with `Trial.get_json_state`
+ and then creating a new trial using the `Trial.from_json_state` alternate
+ constructor loads the trial with equivalent state."""
trial = Trial(
"MockTrainable", stub=True, trial_id="abcd1234", local_dir=str(tmpdir)
)
@@ -276,22 +274,37 @@ def test_load_trial_from_json_state(tmpdir):
)
)
- json_cp = trial.get_json_state()
- trial_cp = json.loads(json_cp, cls=TuneFunctionDecoder)
# After loading, the trial state should be the same
- new_trial = _load_trial_from_checkpoint(trial_cp.copy(), stub=True)
- assert new_trial.get_json_state() == json_cp
+ json_state = trial.get_json_state()
+ new_trial = Trial.from_json_state(json_state, stub=True)
+ assert new_trial.get_json_state() == json_state
+
+
+def test_change_trial_local_dir(tmpdir):
+ trial = Trial(
+ "MockTrainable", stub=True, trial_id="abcd1234", local_dir=str(tmpdir)
+ )
+ trial.init_logdir()
+ trial.status = Trial.TERMINATED
+
+ checkpoint_logdir = os.path.join(trial.logdir, "checkpoint_00000")
+ trial.checkpoint_manager.on_checkpoint(
+ _TrackedCheckpoint(
+ dir_or_data=checkpoint_logdir,
+ storage_mode=CheckpointStorage.PERSISTENT,
+ metrics={"training_iteration": 1},
+ )
+ )
+
+ assert trial.logdir.startswith(str(tmpdir))
+ assert trial.get_trial_checkpoints()[0].dir_or_data.startswith(str(tmpdir))
# Specify a new local dir, and the logdir/checkpoint path should be updated
with tempfile.TemporaryDirectory() as new_local_dir:
- new_trial = _load_trial_from_checkpoint(
- trial_cp.copy(), stub=True, new_local_dir=new_local_dir
- )
+ trial.local_dir = new_local_dir
- assert new_trial.logdir.startswith(new_local_dir)
- assert new_trial.get_trial_checkpoints()[0].dir_or_data.startswith(
- new_local_dir
- )
+ assert trial.logdir.startswith(new_local_dir)
+ assert trial.get_trial_checkpoints()[0].dir_or_data.startswith(new_local_dir)
if __name__ == "__main__":
diff --git a/python/ray/tune/tests/test_trial_runner_3.py b/python/ray/tune/tests/test_trial_runner_3.py
index 9b0c10f6db4be..a50b714553de2 100644
--- a/python/ray/tune/tests/test_trial_runner_3.py
+++ b/python/ray/tune/tests/test_trial_runner_3.py
@@ -905,7 +905,6 @@ def delete(self, remote_dir: str) -> bool:
sync_config=SyncConfig(
upload_dir="fake", syncer=CustomSyncer(), sync_period=0
),
- remote_checkpoint_dir="fake",
trial_executor=RayTrialExecutor(resource_manager=self._resourceManager()),
)
runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 1}))
@@ -951,7 +950,6 @@ def delete(self, remote_dir: str) -> bool:
runner = TrialRunner(
local_checkpoint_dir=self.tmpdir,
sync_config=SyncConfig(upload_dir="fake", syncer=syncer),
- remote_checkpoint_dir="fake",
trial_checkpoint_config=checkpoint_config,
checkpoint_period=100, # Only rely on forced syncing
trial_executor=RayTrialExecutor(resource_manager=self._resourceManager()),
@@ -1019,7 +1017,6 @@ def testForcedCloudCheckpointSyncTimeout(self):
runner = TrialRunner(
local_checkpoint_dir=self.tmpdir,
sync_config=SyncConfig(upload_dir="fake", syncer=syncer),
- remote_checkpoint_dir="fake",
)
# Checkpoint for the first time starts the first sync in the background
runner.checkpoint(force=True)
@@ -1047,7 +1044,6 @@ def testPeriodicCloudCheckpointSyncTimeout(self):
runner = TrialRunner(
local_checkpoint_dir=self.tmpdir,
sync_config=SyncConfig(upload_dir="fake", syncer=syncer),
- remote_checkpoint_dir="fake",
)
with freeze_time() as frozen:
diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py
index 11a7aa37a670f..e77585c783134 100644
--- a/python/ray/tune/tests/test_trial_scheduler.py
+++ b/python/ray/tune/tests/test_trial_scheduler.py
@@ -855,6 +855,8 @@ def __init__(self, i, config):
self.resources = Resources(1, 0)
self.custom_trial_name = None
self.custom_dirname = None
+ self._local_dir = None
+ self.relative_logdir = None
self._default_result_or_future = None
self.checkpoint_manager = _CheckpointManager(
checkpoint_config=CheckpointConfig(
diff --git a/python/ray/tune/tests/test_tune_restore_warm_start.py b/python/ray/tune/tests/test_tune_restore_warm_start.py
index 4530cc6182748..368c4af097e09 100644
--- a/python/ray/tune/tests/test_tune_restore_warm_start.py
+++ b/python/ray/tune/tests/test_tune_restore_warm_start.py
@@ -264,7 +264,7 @@ def cost(space, reporter):
search_alg = NevergradSearch(
optimizer,
- parameter_names,
+ space=parameter_names,
metric="loss",
mode="min",
)
diff --git a/python/ray/tune/tests/test_tuner_restore.py b/python/ray/tune/tests/test_tuner_restore.py
index b528ca80590a2..edca293e1decc 100644
--- a/python/ray/tune/tests/test_tuner_restore.py
+++ b/python/ray/tune/tests/test_tuner_restore.py
@@ -18,7 +18,12 @@
ScalingConfig,
session,
)
-from ray.air._internal.remote_storage import delete_at_uri, download_from_uri
+from ray.air._internal.remote_storage import (
+ delete_at_uri,
+ download_from_uri,
+ upload_to_uri,
+ list_at_uri,
+)
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.tune import Callback, Trainable
from ray.tune.execution.trial_runner import _find_newest_experiment_checkpoint
@@ -54,6 +59,12 @@ def chdir_tmpdir(tmpdir):
os.chdir(old_cwd)
+@pytest.fixture
+def clear_memory_filesys():
+ yield
+ delete_at_uri("memory:///")
+
+
def _train_fn_sometimes_failing(config):
# Fails if failing is set and marker file exists.
# Hangs if hanging is set and marker file exists.
@@ -365,7 +376,7 @@ def test_tuner_resume_errored_only(ray_start_2_cpus, tmpdir):
assert sorted([r.metrics.get("it", 0) for r in results]) == sorted([2, 1, 3, 0])
-def test_tuner_restore_from_cloud(ray_start_2_cpus, tmpdir):
+def test_tuner_restore_from_cloud(ray_start_2_cpus, tmpdir, clear_memory_filesys):
"""Check that restoring Tuner() objects from cloud storage works"""
tuner = Tuner(
lambda config: 1,
@@ -419,7 +430,7 @@ def test_tuner_restore_from_cloud(ray_start_2_cpus, tmpdir):
[None, "memory:///test/test_tuner_restore_latest_available_checkpoint"],
)
def test_tuner_restore_latest_available_checkpoint(
- ray_start_4_cpus, tmpdir, upload_uri
+ ray_start_4_cpus, tmpdir, upload_uri, clear_memory_filesys
):
"""Resuming errored trials should pick up from previous state"""
fail_marker = tmpdir / "fail_marker"
@@ -746,6 +757,68 @@ def test_tuner_restore_from_moved_experiment_path(
assert not old_local_dir.exists()
+def test_tuner_restore_from_moved_cloud_uri(
+ ray_start_2_cpus, tmp_path, clear_memory_filesys
+):
+ """Test that restoring an experiment that was moved to a new remote URI
+ resumes and continues saving new results at that URI."""
+
+ def failing_fn(config):
+ data = {"score": 1}
+ session.report(data, checkpoint=Checkpoint.from_dict(data))
+ raise RuntimeError("Failing!")
+
+ tuner = Tuner(
+ failing_fn,
+ run_config=RunConfig(
+ name="exp_dir",
+ local_dir=str(tmp_path / "ray_results"),
+ sync_config=tune.SyncConfig(upload_dir="memory:///original"),
+ ),
+ tune_config=TuneConfig(trial_dirname_creator=lambda _: "test"),
+ )
+ tuner.fit()
+
+ # mv memory:///original/exp_dir memory:///moved/new_exp_dir
+ download_from_uri(
+ "memory:///original/exp_dir", str(tmp_path / "moved" / "new_exp_dir")
+ )
+ delete_at_uri("memory:///original")
+ upload_to_uri(str(tmp_path / "moved"), "memory:///moved")
+
+ tuner = Tuner.restore("memory:///moved/new_exp_dir", resume_errored=True)
+ # Just for the test, since we're using `memory://` to mock a remote filesystem,
+ # the checkpoint needs to be copied to the new local directory.
+ # This is because the trainable actor uploads its checkpoints to a
+ # different `memory://` filesystem than the driver and is not
+ # downloaded along with the other parts of the experiment dir.
+ # NOTE: A new local directory is used since the experiment name got modified.
+ shutil.move(
+ tmp_path / "ray_results/exp_dir/test/checkpoint_000000",
+ tmp_path / "ray_results/new_exp_dir/test/checkpoint_000000",
+ )
+ results = tuner.fit()
+
+ assert list_at_uri("memory:///") == ["moved"]
+ num_experiment_checkpoints = len(
+ [
+ path
+ for path in list_at_uri("memory:///moved/new_exp_dir")
+ if path.startswith("experiment_state")
+ ]
+ )
+ assert num_experiment_checkpoints == 2
+
+ num_trial_checkpoints = len(
+ [
+ path
+ for path in os.listdir(results[0].log_dir)
+ if path.startswith("checkpoint_")
+ ]
+ )
+ assert num_trial_checkpoints == 2
+
+
def test_restore_from_relative_path(ray_start_4_cpus, chdir_tmpdir):
tuner = Tuner(
lambda config: session.report({"score": 1}),
@@ -812,7 +885,7 @@ def on_trial_result(self, runner, trial, result):
@pytest.mark.parametrize("use_air_trainer", [True, False])
-def test_checkpoints_saved_after_resume(tmp_path, use_air_trainer):
+def test_checkpoints_saved_after_resume(ray_start_2_cpus, tmp_path, use_air_trainer):
"""Checkpoints saved after experiment restore should pick up at the correct
iteration and should not overwrite the checkpoints from the original run.
Old checkpoints should still be deleted if the total number of checkpoints
diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py
index 1e09b1e0fbb9b..a577999dc3fb3 100644
--- a/python/ray/tune/tune.py
+++ b/python/ray/tune/tune.py
@@ -696,7 +696,7 @@ class and registered trainables.
search_alg=search_alg,
scheduler=scheduler,
local_checkpoint_dir=experiments[0].checkpoint_dir,
- remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
+ experiment_dir_name=experiments[0].dir_name,
sync_config=sync_config,
stopper=experiments[0].stopper,
resume=resume,
diff --git a/python/ray/util/client/server/logservicer.py b/python/ray/util/client/server/logservicer.py
index 83351f64ca764..764e6c82c6534 100644
--- a/python/ray/util/client/server/logservicer.py
+++ b/python/ray/util/client/server/logservicer.py
@@ -54,9 +54,6 @@ def unregister_global(self):
def log_status_change_thread(log_queue, request_iterator):
- """This is run in a separate thread and therefore needs a separate logging
- configuration outside of the default ray logging configuration.
- """
std_handler = StdStreamHandler(log_queue)
current_handler = None
root_logger = logging.getLogger("ray")
diff --git a/python/ray/util/multiprocessing/pool.py b/python/ray/util/multiprocessing/pool.py
index b2fee2a9c4a44..fa659a5b4ae55 100644
--- a/python/ray/util/multiprocessing/pool.py
+++ b/python/ray/util/multiprocessing/pool.py
@@ -7,12 +7,14 @@
import queue
import sys
import threading
+import warnings
import time
from multiprocessing import TimeoutError
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Tuple
import ray
from ray.util import log_once
+from ray.util.annotations import RayDeprecationWarning
try:
from joblib._parallel_backends import SafeFunction
@@ -391,7 +393,16 @@ def __init__(self, pool, func, iterable, chunksize=None):
try:
self._iterator = iter(iterable)
except TypeError:
- # for compatibility with prior releases, encapsulate non-iterable in a list
+ warnings.warn(
+ "Passing a non-iterable argument to the "
+ "ray.util.multiprocessing.Pool imap and imap_unordered "
+ "methods is deprecated as of Ray 2.3 and "
+ " will be removed in a future release. See "
+ "https://github.com/ray-project/ray/issues/24237 for more "
+ "information.",
+ category=RayDeprecationWarning,
+ stacklevel=3,
+ )
iterable = [iterable]
self._iterator = iter(iterable)
if isinstance(iterable, collections.abc.Iterator):
diff --git a/python/ray/util/spark/__init__.py b/python/ray/util/spark/__init__.py
index ddfd2de835bcb..eaa75205848c8 100644
--- a/python/ray/util/spark/__init__.py
+++ b/python/ray/util/spark/__init__.py
@@ -1,7 +1,7 @@
from ray.util.spark.cluster_init import (
- init_ray_cluster,
+ setup_ray_cluster,
shutdown_ray_cluster,
MAX_NUM_WORKER_NODES,
)
-__all__ = ["init_ray_cluster", "shutdown_ray_cluster", "MAX_NUM_WORKER_NODES"]
+__all__ = ["setup_ray_cluster", "shutdown_ray_cluster", "MAX_NUM_WORKER_NODES"]
diff --git a/python/ray/util/spark/cluster_init.py b/python/ray/util/spark/cluster_init.py
index 44949ff68c0e5..c8f8b347f4bc7 100644
--- a/python/ray/util/spark/cluster_init.py
+++ b/python/ray/util/spark/cluster_init.py
@@ -6,7 +6,7 @@
import logging
import uuid
from packaging.version import Version
-from typing import Optional, Dict
+from typing import Optional, Dict, Type
import ray
from ray.util.annotations import PublicAPI
@@ -48,7 +48,7 @@ def _check_system_environment():
try:
import pyspark
- if Version(pyspark.__version__) < Version("3.3"):
+ if Version(pyspark.__version__).release < (3, 3, 0):
raise RuntimeError(spark_dependency_error)
except ImportError:
raise RuntimeError(spark_dependency_error)
@@ -56,7 +56,7 @@ def _check_system_environment():
class RayClusterOnSpark:
"""
- This class is the type of instance returned by the `init_ray_cluster` API.
+ This class is the type of instance returned by the `_setup_ray_cluster` interface.
Its main functionality is to:
Connect to, disconnect from, and shutdown the Ray cluster running on Apache Spark.
Serve as a Python context manager for the `RayClusterOnSpark` instance.
@@ -77,6 +77,8 @@ def __init__(
num_workers_node,
temp_dir,
cluster_unique_id,
+ start_hook,
+ ray_dashboard_port,
):
self.address = address
self.head_proc = head_proc
@@ -84,8 +86,9 @@ def __init__(
self.num_worker_nodes = num_workers_node
self.temp_dir = temp_dir
self.cluster_unique_id = cluster_unique_id
+ self.start_hook = start_hook
+ self.ray_dashboard_port = ray_dashboard_port
- self.ray_context = None
self.is_shutdown = False
self.spark_job_is_canceled = False
self.background_job_exception = None
@@ -94,7 +97,7 @@ def _cancel_background_spark_job(self):
self.spark_job_is_canceled = True
get_spark_session().sparkContext.cancelJobGroup(self.spark_job_group_id)
- def connect(self):
+ def wait_until_ready(self):
import ray
if self.background_job_exception is not None:
@@ -106,14 +109,18 @@ def connect(self):
raise RuntimeError(
"The ray cluster has been shut down or it failed to start."
)
- if self.ray_context is None:
- try:
- # connect to the ray cluster.
- self.ray_context = ray.init(address=self.address)
- except Exception:
- self.shutdown()
- raise
+ try:
+ # connect to the ray cluster.
+ ray_ctx = ray.init(address=self.address)
+ webui_url = ray_ctx.address_info.get("webui_url", None)
+ if webui_url:
+ self.start_hook.on_ray_dashboard_created(self.ray_dashboard_port)
+ except Exception:
+ self.shutdown()
+ raise
+
+ try:
last_alive_worker_count = 0
last_progress_move_time = time.time()
while True:
@@ -145,33 +152,27 @@ def connect(self):
"failed to start."
)
return
- else:
- _logger.warning("Already connected to this ray cluster.")
+ finally:
+ ray.shutdown()
+
+ def connect(self):
+ if ray.is_initialized():
+ raise RuntimeError("Already connected to Ray cluster.")
+ ray.init(address=self.address)
def disconnect(self):
- if self.ray_context is not None:
- try:
- self.ray_context.disconnect()
- except Exception as e:
- # swallow exception.
- _logger.warning(
- f"An error occurred while disconnecting from the ray cluster: "
- f"{repr(e)}"
- )
- self.ray_context = None
- else:
- _logger.warning("Already disconnected from this ray cluster.")
+ ray.shutdown()
def shutdown(self, cancel_background_job=True):
"""
- Shutdown the ray cluster created by the `init_ray_cluster` API.
+ Shutdown the ray cluster created by the `setup_ray_cluster` API.
NB: In the background thread that runs the background spark job, if spark job
raise unexpected error, its exception handler will also call this method, in
the case, it will set cancel_background_job=False to avoid recursive call.
"""
if not self.is_shutdown:
- if self.ray_context is not None:
- self.disconnect()
+ self.disconnect()
+ os.environ.pop("RAY_ADDRESS", None)
if cancel_background_job:
try:
self._cancel_background_spark_job()
@@ -191,15 +192,18 @@ def shutdown(self, cancel_background_job=True):
self.is_shutdown = True
def __enter__(self):
- self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.shutdown()
+def _convert_ray_node_option_key(key):
+ return f"--{key.replace('_', '-')}"
+
+
def _convert_ray_node_options(options):
- return [f"--{k.replace('_', '-')}={str(v)}" for k, v in options.items()]
+ return [f"{_convert_ray_node_option_key(k)}={str(v)}" for k, v in options.items()]
_RAY_HEAD_STARTUP_TIMEOUT = 5
@@ -351,25 +355,32 @@ def hold_lock():
return worker_port_range_begin, worker_port_range_end
-def _init_ray_cluster(
- num_worker_nodes,
- object_store_memory_per_node=None,
- head_options=None,
- worker_options=None,
- ray_temp_root_dir=None,
- safe_mode=False,
- collect_log_to_path=None,
-):
+def _setup_ray_cluster(
+ *,
+ num_worker_nodes: int,
+ num_cpus_per_node: int,
+ num_gpus_per_node: int,
+ using_stage_scheduling: bool,
+ heap_memory_per_node: int,
+ object_store_memory_per_node: int,
+ head_node_options: Dict,
+ worker_node_options: Dict,
+ ray_temp_root_dir: str,
+ collect_log_to_path: str,
+) -> Type[RayClusterOnSpark]:
"""
- This function is used in testing, it has the same arguments with
- `ray.util.spark.init_ray_cluster` API, but it returns a `RayClusterOnSpark`
- instance instead.
+ The public API `ray.util.spark.setup_ray_cluster` does some argument
+ validation and then pass validated arguments to this interface.
+ and it returns a `RayClusterOnSpark` instance.
The returned instance can be used to connect to, disconnect from and shutdown the
ray cluster. This instance can also be used as a context manager (used by
- encapsulating operations within `with init_ray_cluster(...):`). Upon entering the
+ encapsulating operations within `with _setup_ray_cluster(...):`). Upon entering the
managed scope, the ray cluster is initiated and connected to. When exiting the
scope, the ray cluster is disconnected and shut down.
+
+ Note: This function interface is stable and can be used for
+ instrumentation logging patching.
"""
from pyspark.util import inheritable_thread_target
@@ -380,83 +391,40 @@ def _init_ray_cluster(
else:
start_hook = RayOnSparkStartHook()
- head_options = head_options or {}
- worker_options = worker_options or {}
-
spark = get_spark_session()
- # Environment configurations within the Spark Session that dictate how many cpus
- # and gpus to use for each submitted spark task.
- num_spark_task_cpus = int(spark.sparkContext.getConf().get("spark.task.cpus", "1"))
- num_spark_task_gpus = int(
- spark.sparkContext.getConf().get("spark.task.resource.gpu.amount", "0")
- )
-
- (
- ray_worker_node_heap_mem_bytes,
- ray_worker_node_object_store_mem_bytes,
- ) = get_avail_mem_per_ray_worker_node(spark, object_store_memory_per_node)
-
- max_concurrent_tasks = get_max_num_concurrent_tasks(spark.sparkContext)
- if num_worker_nodes == -1:
- # num_worker_nodes=-1 represents using all available spark task slots
- num_worker_nodes = max_concurrent_tasks
- elif num_worker_nodes <= 0:
- raise ValueError(
- "The value of 'num_worker_nodes' argument must be either a positive "
- "integer or 'ray.util.spark.MAX_NUM_WORKER_NODES'."
- )
-
- insufficient_resources = []
+ ray_head_ip = socket.gethostbyname(get_spark_application_driver_host(spark))
+ ray_head_port = get_random_unused_port(ray_head_ip, min_port=9000, max_port=10000)
- if num_spark_task_cpus < 4:
- insufficient_resources.append(
- "The provided CPU resources for each ray worker are inadequate to start "
- "a ray cluster. Based on the total cpu resources available and the "
- "configured task sizing, each ray worker would start with "
- f"{num_spark_task_cpus} CPU cores. This is less than the recommended "
- "value of `4` CPUs per worker. Increasing the spark configuration "
- "'spark.task.cpus' to a minimum of `4` addresses it."
- )
+ include_dashboard = head_node_options.pop("include_dashboard", None)
+ ray_dashboard_port = head_node_options.pop("dashboard_port", None)
- if ray_worker_node_heap_mem_bytes < 10 * 1024 * 1024 * 1024:
- insufficient_resources.append(
- "The provided memory resources for each ray worker are inadequate. Based "
- "on the total memory available on the spark cluster and the configured "
- "task sizing, each ray worker would start with "
- f"{ray_worker_node_heap_mem_bytes} bytes heap memory. This is less than "
- "the recommended value of 10GB. The ray worker node heap memory size is "
- "calculated by "
- "(SPARK_WORKER_NODE_PHYSICAL_MEMORY / num_local_spark_task_slots * 0.8) - "
- "object_store_memory_per_node. To increase the heap space available, "
- "increase the memory in the spark cluster by changing instance types or "
- "worker count, reduce the target `num_worker_nodes`, or apply a lower "
- "`object_store_memory_per_node`."
- )
- if insufficient_resources:
- if safe_mode:
- raise ValueError(
- "You are creating ray cluster on spark with safe mode (it can be "
- "disabled by setting argument 'safe_mode=False' when calling API "
- "'init_ray_cluster'), safe mode requires the spark cluster config "
- "satisfying following criterion: "
- "\n".join(insufficient_resources)
+ if include_dashboard is None or include_dashboard is True:
+ if ray_dashboard_port is None:
+ ray_dashboard_port = get_random_unused_port(
+ ray_head_ip, min_port=9000, max_port=10000, exclude_list=[ray_head_port]
)
- else:
- _logger.warning("\n".join(insufficient_resources))
-
- ray_head_ip = socket.gethostbyname(get_spark_application_driver_host(spark))
+ ray_dashboard_agent_port = get_random_unused_port(
+ ray_head_ip,
+ min_port=9000,
+ max_port=10000,
+ exclude_list=[ray_head_port, ray_dashboard_port],
+ )
- ray_head_port = get_random_unused_port(ray_head_ip, min_port=9000, max_port=10000)
- ray_dashboard_port = get_random_unused_port(
- ray_head_ip, min_port=9000, max_port=10000, exclude_list=[ray_head_port]
- )
- ray_dashboard_agent_port = get_random_unused_port(
- ray_head_ip,
- min_port=9000,
- max_port=10000,
- exclude_list=[ray_head_port, ray_dashboard_port],
- )
+ dashboard_options = [
+ "--dashboard-host=0.0.0.0",
+ f"--dashboard-port={ray_dashboard_port}",
+ f"--dashboard-agent-listen-port={ray_dashboard_agent_port}",
+ ]
+ # If include_dashboard is None, we don't set `--include-dashboard` option,
+ # in this case Ray will decide whether dashboard can be started
+ # (e.g. checking any missing dependencies).
+ if include_dashboard is True:
+ dashboard_options += ["--include-dashboard=true"]
+ else:
+ dashboard_options = [
+ "--include-dashboard=false",
+ ]
_logger.info(f"Ray head hostname {ray_head_ip}, port {ray_head_port}")
@@ -478,20 +446,18 @@ def _init_ray_cluster(
"--head",
f"--node-ip-address={ray_head_ip}",
f"--port={ray_head_port}",
- "--include-dashboard=true",
- "--dashboard-host=0.0.0.0",
- f"--dashboard-port={ray_dashboard_port}",
- f"--dashboard-agent-listen-port={ray_dashboard_agent_port}",
- # disallow ray tasks with cpu requirements from being scheduled on the head
+ # disallow ray tasks with cpu/gpu requirements from being scheduled on the head
# node.
"--num-cpus=0",
+ "--num-gpus=0",
# limit the memory allocation to the head node (actual usage may increase
# beyond this for processing of tasks and actors).
f"--memory={128 * 1024 * 1024}",
# limit the object store memory allocation to the head node (actual usage
# may increase beyond this for processing of tasks and actors).
f"--object-store-memory={128 * 1024 * 1024}",
- *_convert_ray_node_options(head_options),
+ *dashboard_options,
+ *_convert_ray_node_options(head_node_options),
]
_logger.info(f"Starting Ray head, command: {' '.join(ray_head_node_cmd)}")
@@ -522,8 +488,6 @@ def _init_ray_cluster(
_logger.info("Ray head node started.")
- start_hook.on_ray_dashboard_created(ray_dashboard_port)
-
# NB:
# In order to start ray worker nodes on spark cluster worker machines,
# We launch a background spark job:
@@ -571,22 +535,22 @@ def ray_cluster_job_mapper(_):
"-m",
"ray.util.spark.start_ray_node",
f"--temp-dir={ray_temp_dir}",
- f"--num-cpus={num_spark_task_cpus}",
+ f"--num-cpus={num_cpus_per_node}",
"--block",
f"--address={ray_head_ip}:{ray_head_port}",
- f"--memory={ray_worker_node_heap_mem_bytes}",
- f"--object-store-memory={ray_worker_node_object_store_mem_bytes}",
+ f"--memory={heap_memory_per_node}",
+ f"--object-store-memory={object_store_memory_per_node}",
f"--min-worker-port={worker_port_range_begin}",
f"--max-worker-port={worker_port_range_end - 1}",
f"--dashboard-agent-listen-port={ray_worker_node_dashboard_agent_port}",
- *_convert_ray_node_options(worker_options),
+ *_convert_ray_node_options(worker_node_options),
]
ray_worker_node_extra_envs = {
RAY_ON_SPARK_COLLECT_LOG_TO_PATH: collect_log_to_path or ""
}
- if num_spark_task_gpus > 0:
+ if num_gpus_per_node > 0:
task_resources = context.resources()
if "gpu" not in task_resources:
@@ -633,13 +597,19 @@ def ray_cluster_job_mapper(_):
spark_job_group_id = f"ray-cluster-{ray_head_port}-{cluster_unique_id}"
+ cluster_address = f"{ray_head_ip}:{ray_head_port}"
+ # Set RAY_ADDRESS environment variable to the cluster address.
+ os.environ["RAY_ADDRESS"] = cluster_address
+
ray_cluster_handler = RayClusterOnSpark(
- address=f"{ray_head_ip}:{ray_head_port}",
+ address=cluster_address,
head_proc=ray_head_proc,
spark_job_group_id=spark_job_group_id,
num_workers_node=num_worker_nodes,
temp_dir=ray_temp_dir,
cluster_unique_id=cluster_unique_id,
+ start_hook=start_hook,
+ ray_dashboard_port=ray_dashboard_port,
)
def background_job_thread_fn():
@@ -670,9 +640,18 @@ def background_job_thread_fn():
# slots become available, it continues to launch tasks on new available
# slots, and user can see the ray cluster worker number increases when more
# slots available.
- spark.sparkContext.parallelize(
+ job_rdd = spark.sparkContext.parallelize(
list(range(num_worker_nodes)), num_worker_nodes
- ).mapPartitions(ray_cluster_job_mapper).collect()
+ )
+
+ if using_stage_scheduling:
+ resource_profile = _create_resource_profile(
+ num_cpus_per_node,
+ num_gpus_per_node,
+ )
+ job_rdd = job_rdd.withResources(resource_profile)
+
+ job_rdd.mapPartitions(ray_cluster_job_mapper).collect()
except Exception as e:
# NB:
# The background spark job is designed to running forever until it is
@@ -698,6 +677,9 @@ def background_job_thread_fn():
target=inheritable_thread_target(background_job_thread_fn), args=()
).start()
+ # Call hook immediately after spark job started.
+ start_hook.on_spark_background_job_created(spark_job_group_id)
+
# wait background spark task starting.
for _ in range(_BACKGROUND_JOB_STARTUP_WAIT):
time.sleep(1)
@@ -706,7 +688,6 @@ def background_job_thread_fn():
"Ray workers failed to start."
) from ray_cluster_handler.background_job_exception
- start_hook.on_spark_background_job_created(spark_job_group_id)
return ray_cluster_handler
except Exception:
# If driver side setup ray-cluster routine raises exception, it might result
@@ -720,54 +701,138 @@ def background_job_thread_fn():
_active_ray_cluster = None
+def _create_resource_profile(num_cpus_per_node, num_gpus_per_node):
+ from pyspark.resource.profile import ResourceProfileBuilder
+ from pyspark.resource.requests import TaskResourceRequests
+
+ task_res_req = TaskResourceRequests().cpus(num_cpus_per_node)
+ if num_gpus_per_node > 0:
+ task_res_req = task_res_req.resource("gpu", num_gpus_per_node)
+ return ResourceProfileBuilder().require(task_res_req).build
+
+
+# A dict storing blocked key to replacement argument you should use.
+_head_node_option_block_keys = {
+ "temp_dir": "ray_temp_root_dir",
+ "block": None,
+ "head": None,
+ "node_ip_address": None,
+ "port": None,
+ "num_cpus": None,
+ "num_gpus": None,
+ "memory": None,
+ "object_store_memory": None,
+ "dashboard_host": None,
+ "dashboard_agent_listen_port": None,
+}
+
+_worker_node_option_block_keys = {
+ "temp_dir": "ray_temp_root_dir",
+ "block": None,
+ "head": None,
+ "address": None,
+ "num_cpus": "num_cpus_per_node",
+ "num_gpus": "num_gpus_per_node",
+ "memory": None,
+ "object_store_memory": "object_store_memory_per_node",
+ "dashboard_agent_listen_port": None,
+ "min_worker_port": None,
+ "max_worker_port": None,
+}
+
+
+def _verify_node_options(node_options, block_keys, node_type):
+ for key in node_options:
+ if key.startswith("--") or "-" in key:
+ raise ValueError(
+ "For a ray node option like '--foo-bar', you should convert it to "
+ "following format 'foo_bar' in 'head_node_options' / "
+ "'worker_node_options' arguments."
+ )
+
+ if key in block_keys:
+ common_err_msg = (
+ f"Setting option {_convert_ray_node_options(key)} for {node_type} "
+ "is not allowed."
+ )
+ replacement_arg = block_keys[key]
+ if replacement_arg:
+ raise ValueError(
+ f"{common_err_msg} You should set '{replacement_arg}' argument "
+ "instead."
+ )
+ else:
+ raise ValueError(
+ f"{common_err_msg} The option is controlled by Ray on Spark "
+ "routine."
+ )
+
+
@PublicAPI(stability="alpha")
-def init_ray_cluster(
+def setup_ray_cluster(
num_worker_nodes: int,
+ num_cpus_per_node: Optional[int] = None,
+ num_gpus_per_node: Optional[int] = None,
object_store_memory_per_node: Optional[int] = None,
- head_options: Optional[Dict] = None,
- worker_options: Optional[Dict] = None,
+ head_node_options: Optional[Dict] = None,
+ worker_node_options: Optional[Dict] = None,
ray_temp_root_dir: Optional[str] = None,
- safe_mode: Optional[bool] = False,
+ strict_mode: bool = False,
collect_log_to_path: Optional[str] = None,
) -> str:
"""
- Initialize a ray cluster on the spark cluster by starting a ray head node in the
+ Set up a ray cluster on the spark cluster by starting a ray head node in the
spark application's driver side node.
After creating the head node, a background spark job is created that
generates an instance of `RayClusterOnSpark` that contains configuration for the
ray cluster that will run on the Spark cluster's worker nodes.
- After a ray cluster initialized, your python process automatically connect to the
- ray cluster, you can call `ray.util.spark.shutdown_ray_cluster` to shut down the
- ray cluster.
+ After a ray cluster is set up, "RAY_ADDRESS" environment variable is set to
+ the cluster address, so you can call `ray.init()` without specifying ray cluster
+ address to connect to the cluster. To shut down the cluster you can call
+ `ray.util.spark.shutdown_ray_cluster()`.
Note: If the active ray cluster haven't shut down, you cannot create a new ray
cluster.
- Args
- num_worker_nodes: The number of spark worker nodes that the spark job will be
- submitted to. This argument represents how many concurrent spark tasks will
- be available in the creation of the ray cluster. The ray cluster's total
- available resources (memory, CPU and/or GPU) is equal to the quantity of
- resources allocated within these spark tasks.
- Specifying the `num_worker_nodes` as `-1` represents a ray cluster
- configuration that will use all available spark tasks slots (and resources
- allocated to the spark application) on the spark cluster.
- To create a spark cluster that is intended to be used exclusively as a
+ Args:
+ num_worker_nodes: This argument represents how many ray worker nodes to start
+ for the ray cluster.
+ Specifying the `num_worker_nodes` as `ray.util.spark.MAX_NUM_WORKER_NODES`
+ represents a ray cluster
+ configuration that will use all available resources configured for the
+ spark application.
+ To create a spark application that is intended to exclusively run a
shared ray cluster, it is recommended to set this argument to
- `ray.spark.utils.MAX_NUM_WORKER_NODES`.
+ `ray.util.spark.MAX_NUM_WORKER_NODES`.
+ num_cpus_per_node: Number of cpus available to per-ray worker node, if not
+ provided, use spark application configuration 'spark.task.cpus' instead.
+ **Limitation** Only spark version >= 3.4 or Databricks Runtime 12.x
+ supports setting this argument.
+ num_gpus_per_node: Number of gpus available to per-ray worker node, if not
+ provided, use spark application configuration
+ 'spark.task.resource.gpu.amount' instead.
+ This argument is only available on spark cluster that is configured with
+ 'gpu' resources.
+ **Limitation** Only spark version >= 3.4 or Databricks Runtime 12.x
+ supports setting this argument.
object_store_memory_per_node: Object store memory available to per-ray worker
node, but it is capped by
"dev_shm_available_size * 0.8 / num_tasks_per_spark_worker".
The default value equals to
- "dev_shm_available_size * 0.8 / num_tasks_per_spark_worker".
- head_options: A dict representing Ray head node options.
- worker_options: A dict representing Ray worker node options.
+ "0.3 * spark_worker_physical_memory * 0.8 / num_tasks_per_spark_worker".
+ head_node_options: A dict representing Ray head node extra options, these
+ options will be passed to `ray start` script. Note you need to convert
+ `ray start` options key from `--foo-bar` format to `foo_bar` format.
+ worker_node_options: A dict representing Ray worker node extra options,
+ these options will be passed to `ray start` script. Note you need to
+ convert `ray start` options key from `--foo-bar` format to `foo_bar`
+ format.
ray_temp_root_dir: A local disk path to store the ray temporary data. The
created cluster will create a subdirectory
"ray-{head_port}-{random_suffix}" beneath this path.
- safe_mode: Boolean flag to fast-fail initialization of the ray cluster if
+ strict_mode: Boolean flag to fast-fail initialization of the ray cluster if
the available spark cluster does not have sufficient resources to fulfill
the resource allocation for memory, cpu and gpu. When set to true, if the
- requested resources are not available for minimum recommended
+ requested resources are not available for recommended minimum recommended
functionality, an exception will be raised that details the inadequate
spark cluster configuration settings. If overridden as `False`,
a warning is raised.
@@ -784,6 +849,20 @@ def init_ray_cluster(
_check_system_environment()
+ head_node_options = head_node_options or {}
+ worker_node_options = worker_node_options or {}
+
+ _verify_node_options(
+ head_node_options,
+ _head_node_option_block_keys,
+ "Ray head node on spark",
+ )
+ _verify_node_options(
+ worker_node_options,
+ _worker_node_option_block_keys,
+ "Ray worker node on spark",
+ )
+
if _active_ray_cluster is not None:
raise RuntimeError(
"Current active ray cluster on spark haven't shut down. Please call "
@@ -797,16 +876,150 @@ def init_ray_cluster(
"by `ray.shutdown()` before initiating a Ray cluster on spark."
)
- cluster = _init_ray_cluster(
+ spark = get_spark_session()
+
+ spark_master = spark.sparkContext.master
+ if not (
+ spark_master.startswith("spark://") or spark_master.startswith("local-cluster[")
+ ):
+ raise RuntimeError(
+ "Ray on Spark only supports spark cluster in standalone mode or "
+ "local-cluster mode"
+ )
+
+ if (
+ is_in_databricks_runtime()
+ and Version(os.environ["DATABRICKS_RUNTIME_VERSION"]).major >= 12
+ ):
+ support_stage_scheduling = True
+ else:
+ import pyspark
+
+ if Version(pyspark.__version__).release >= (3, 4, 0):
+ support_stage_scheduling = True
+ else:
+ support_stage_scheduling = False
+
+ # Environment configurations within the Spark Session that dictate how many cpus
+ # and gpus to use for each submitted spark task.
+ num_spark_task_cpus = int(spark.sparkContext.getConf().get("spark.task.cpus", "1"))
+
+ if num_cpus_per_node is not None and num_cpus_per_node <= 0:
+ raise ValueError("Argument `num_cpus_per_node` value must be > 0.")
+
+ num_spark_task_gpus = int(
+ spark.sparkContext.getConf().get("spark.task.resource.gpu.amount", "0")
+ )
+
+ if num_gpus_per_node is not None and num_spark_task_gpus == 0:
+ raise ValueError(
+ "The spark cluster is not configured with 'gpu' resources, so that "
+ "you cannot specify the `num_gpus_per_node` argument."
+ )
+
+ if num_gpus_per_node is not None and num_gpus_per_node < 0:
+ raise ValueError("Argument `num_gpus_per_node` value must be >= 0.")
+
+ if num_cpus_per_node is not None or num_gpus_per_node is not None:
+ if support_stage_scheduling:
+ num_cpus_per_node = num_cpus_per_node or num_spark_task_cpus
+ num_gpus_per_node = num_gpus_per_node or num_spark_task_gpus
+
+ using_stage_scheduling = True
+ res_profile = _create_resource_profile(num_cpus_per_node, num_gpus_per_node)
+ else:
+ raise ValueError(
+ "Current spark version does not support stage scheduling, so that "
+ "you cannot set the argument `num_cpus_per_node` and "
+ "`num_gpus_per_node` values. Without setting the 2 arguments, "
+ "per-Ray worker node will be assigned with number of "
+ f"'spark.task.cpus' (equals to {num_spark_task_cpus}) cpu cores "
+ "and number of 'spark.task.resource.gpu.amount' "
+ f"(equals to {num_spark_task_gpus}) GPUs. To enable spark stage "
+ "scheduling, you need to upgrade spark to 3.4 version or use "
+ "Databricks Runtime 12.x."
+ )
+ else:
+ using_stage_scheduling = False
+ res_profile = None
+
+ num_cpus_per_node = num_spark_task_cpus
+ num_gpus_per_node = num_spark_task_gpus
+
+ (
+ ray_worker_node_heap_mem_bytes,
+ ray_worker_node_object_store_mem_bytes,
+ ) = get_avail_mem_per_ray_worker_node(
+ spark,
+ object_store_memory_per_node,
+ num_cpus_per_node,
+ num_gpus_per_node,
+ )
+
+ if num_worker_nodes == MAX_NUM_WORKER_NODES:
+ # num_worker_nodes=MAX_NUM_WORKER_NODES represents using all available
+ # spark task slots
+ num_worker_nodes = get_max_num_concurrent_tasks(spark.sparkContext, res_profile)
+ elif num_worker_nodes <= 0:
+ raise ValueError(
+ "The value of 'num_worker_nodes' argument must be either a positive "
+ "integer or 'ray.util.spark.MAX_NUM_WORKER_NODES'."
+ )
+
+ insufficient_resources = []
+
+ if num_cpus_per_node < 4:
+ insufficient_resources.append(
+ "The provided CPU resources for each ray worker are inadequate to start "
+ "a ray cluster. Based on the total cpu resources available and the "
+ "configured task sizing, each ray worker node would start with "
+ f"{num_cpus_per_node} CPU cores. This is less than the recommended "
+ "value of `4` CPUs per worker. On spark version >= 3.4 or Databricks "
+ "Runtime 12.x, you can set the argument `num_cpus_per_node` to "
+ "a value >= 4 to address it, otherwise you need to increase the spark "
+ "application configuration 'spark.task.cpus' to a minimum of `4` to "
+ "address it."
+ )
+
+ if ray_worker_node_heap_mem_bytes < 10 * 1024 * 1024 * 1024:
+ insufficient_resources.append(
+ "The provided memory resources for each ray worker node are inadequate. "
+ "Based on the total memory available on the spark cluster and the "
+ "configured task sizing, each ray worker would start with "
+ f"{ray_worker_node_heap_mem_bytes} bytes heap memory. This is less than "
+ "the recommended value of 10GB. The ray worker node heap memory size is "
+ "calculated by "
+ "(SPARK_WORKER_NODE_PHYSICAL_MEMORY / num_local_spark_task_slots * 0.8) - "
+ "object_store_memory_per_node. To increase the heap space available, "
+ "increase the memory in the spark cluster by changing instance types or "
+ "worker count, reduce the target `num_worker_nodes`, or apply a lower "
+ "`object_store_memory_per_node`."
+ )
+ if insufficient_resources:
+ if strict_mode:
+ raise ValueError(
+ "You are creating ray cluster on spark with strict mode (it can be "
+ "disabled by setting argument 'strict_mode=False' when calling API "
+ "'setup_ray_cluster'), strict mode requires the spark cluster config "
+ "satisfying following criterion: "
+ "\n".join(insufficient_resources)
+ )
+ else:
+ _logger.warning("\n".join(insufficient_resources))
+
+ cluster = _setup_ray_cluster(
num_worker_nodes=num_worker_nodes,
- object_store_memory_per_node=object_store_memory_per_node,
- head_options=head_options,
- worker_options=worker_options,
+ num_cpus_per_node=num_cpus_per_node,
+ num_gpus_per_node=num_gpus_per_node,
+ using_stage_scheduling=using_stage_scheduling,
+ heap_memory_per_node=ray_worker_node_heap_mem_bytes,
+ object_store_memory_per_node=ray_worker_node_object_store_mem_bytes,
+ head_node_options=head_node_options,
+ worker_node_options=worker_node_options,
ray_temp_root_dir=ray_temp_root_dir,
- safe_mode=safe_mode,
collect_log_to_path=collect_log_to_path,
)
- cluster.connect() # NB: this line might raise error.
+ cluster.wait_until_ready() # NB: this line might raise error.
# If connect cluster successfully, set global _active_ray_cluster to be the started
# cluster.
diff --git a/python/ray/util/spark/utils.py b/python/ray/util/spark/utils.py
index 56413abfa50b9..0169d45892c7f 100644
--- a/python/ray/util/spark/utils.py
+++ b/python/ray/util/spark/utils.py
@@ -7,9 +7,6 @@
import logging
-_MEMORY_BUFFER_OFFSET = 0.8
-
-
_logger = logging.getLogger("ray.util.spark.utils")
@@ -130,91 +127,210 @@ def get_spark_application_driver_host(spark):
return spark.conf.get("spark.driver.host")
-def get_max_num_concurrent_tasks(spark_context):
+def get_max_num_concurrent_tasks(spark_context, resource_profile):
"""Gets the current max number of concurrent tasks."""
- # pylint: disable=protected-access
- # spark version 3.1 and above have a different API for fetching max concurrent
- # tasks
- if spark_context._jsc.sc().version() >= "3.1":
- return spark_context._jsc.sc().maxNumConcurrentTasks(
- spark_context._jsc.sc().resourceProfileManager().resourceProfileFromId(0)
+ # pylint: disable=protected-access=
+ ssc = spark_context._jsc.sc()
+ if resource_profile is not None:
+
+ def dummpy_mapper(_):
+ pass
+
+ # Runs a dummy spark job to register the `res_profile`
+ spark_context.parallelize([1], 1).withResources(resource_profile).map(
+ dummpy_mapper
+ ).collect()
+
+ return ssc.maxNumConcurrentTasks(resource_profile._java_resource_profile)
+ else:
+ return ssc.maxNumConcurrentTasks(
+ ssc.resourceProfileManager().defaultResourceProfile()
)
- return spark_context._jsc.sc().maxNumConcurrentTasks()
def _get_total_physical_memory():
import psutil
+ if RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES in os.environ:
+ return int(os.environ[RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES])
return psutil.virtual_memory().total
def _get_total_shared_memory():
import shutil
+ if RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES in os.environ:
+ return int(os.environ[RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES])
+
return shutil.disk_usage("/dev/shm").total
-def _get_cpu_cores():
- import multiprocessing
+# The maximum proportion for Ray worker node object store memory size
+_RAY_ON_SPARK_MAX_OBJECT_STORE_MEMORY_PROPORTION = 0.8
- return multiprocessing.cpu_count()
+# The buffer offset for calculating Ray node memory.
+_RAY_ON_SPARK_WORKER_MEMORY_BUFFER_OFFSET = 0.8
def _calc_mem_per_ray_worker_node(
- num_task_slots, physical_mem_bytes, shared_mem_bytes, object_store_memory_per_node
+ num_task_slots, physical_mem_bytes, shared_mem_bytes, configured_object_store_bytes
):
+ from ray._private.ray_constants import (
+ DEFAULT_OBJECT_STORE_MEMORY_PROPORTION,
+ OBJECT_STORE_MINIMUM_MEMORY_BYTES,
+ )
+
+ warning_msg = None
+
available_physical_mem_per_node = int(
- physical_mem_bytes / num_task_slots * _MEMORY_BUFFER_OFFSET
+ physical_mem_bytes / num_task_slots * _RAY_ON_SPARK_WORKER_MEMORY_BUFFER_OFFSET
)
available_shared_mem_per_node = int(
- shared_mem_bytes / num_task_slots * _MEMORY_BUFFER_OFFSET
+ shared_mem_bytes / num_task_slots * _RAY_ON_SPARK_WORKER_MEMORY_BUFFER_OFFSET
)
- if object_store_memory_per_node is None:
+
+ object_store_bytes = configured_object_store_bytes or (
+ available_physical_mem_per_node * DEFAULT_OBJECT_STORE_MEMORY_PROPORTION
+ )
+
+ if object_store_bytes > available_shared_mem_per_node:
object_store_bytes = available_shared_mem_per_node
- else:
- object_store_bytes = int(
- min(
- object_store_memory_per_node,
- available_shared_mem_per_node,
- )
+
+ object_store_bytes_upper_bound = (
+ available_physical_mem_per_node
+ * _RAY_ON_SPARK_MAX_OBJECT_STORE_MEMORY_PROPORTION
+ )
+
+ if object_store_bytes > object_store_bytes_upper_bound:
+ object_store_bytes = object_store_bytes_upper_bound
+ warning_msg = (
+ "Your configured `object_store_memory_per_node` value "
+ "is too high and it is capped by 80% of per-Ray node "
+ "allocated memory."
+ )
+
+ if object_store_bytes < OBJECT_STORE_MINIMUM_MEMORY_BYTES:
+ object_store_bytes = OBJECT_STORE_MINIMUM_MEMORY_BYTES
+ warning_msg = (
+ "Your operating system is configured with too small /dev/shm "
+ "size, so `object_store_memory_per_node` value is configured "
+ f"to minimal size ({OBJECT_STORE_MINIMUM_MEMORY_BYTES} bytes),"
+ f"Please increase system /dev/shm size."
)
+
+ object_store_bytes = int(object_store_bytes)
+
heap_mem_bytes = available_physical_mem_per_node - object_store_bytes
- return heap_mem_bytes, object_store_bytes
+ return heap_mem_bytes, object_store_bytes, warning_msg
+
+# User can manually set these environment variables
+# if ray on spark code accessing corresponding information failed.
+# Note these environment variables must be set in spark executor side,
+# you should set them via setting spark config of
+# `spark.executorEnv.[EnvironmentVariableName]`
+RAY_ON_SPARK_WORKER_CPU_CORES = "RAY_ON_SPARK_WORKER_CPU_CORES"
+RAY_ON_SPARK_WORKER_GPU_NUM = "RAY_ON_SPARK_WORKER_GPU_NUM"
+RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES = "RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES"
+RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES = "RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES"
-def get_avail_mem_per_ray_worker_node(spark, object_store_memory_per_node):
+
+def _get_cpu_cores():
+ import multiprocessing
+
+ if RAY_ON_SPARK_WORKER_CPU_CORES in os.environ:
+ # In some cases, spark standalone cluster might configure virtual cpu cores
+ # for spark worker that different with number of physical cpu cores,
+ # but we cannot easily get the virtual cpu cores configured for spark
+ # worker, as a workaround, we provide an environmental variable config
+ # `RAY_ON_SPARK_WORKER_CPU_CORES` for user.
+ return int(os.environ[RAY_ON_SPARK_WORKER_CPU_CORES])
+
+ return multiprocessing.cpu_count()
+
+
+def _get_num_physical_gpus():
+ if RAY_ON_SPARK_WORKER_GPU_NUM in os.environ:
+ # In some cases, spark standalone cluster might configure part of physical
+ # GPUs for spark worker,
+ # but we cannot easily get related configuration,
+ # as a workaround, we provide an environmental variable config
+ # `RAY_ON_SPARK_WORKER_CPU_CORES` for user.
+ return int(os.environ[RAY_ON_SPARK_WORKER_GPU_NUM])
+
+ try:
+ completed_proc = subprocess.run(
+ "nvidia-smi --query-gpu=name --format=csv,noheader",
+ shell=True,
+ check=True,
+ text=True,
+ capture_output=True,
+ )
+ except Exception as e:
+ raise RuntimeError(
+ "Running command `nvidia-smi` for inferring GPU devices list failed."
+ ) from e
+ return len(completed_proc.stdout.strip().split("\n"))
+
+
+def _get_avail_mem_per_ray_worker_node(
+ num_cpus_per_node,
+ num_gpus_per_node,
+ object_store_memory_per_node,
+):
+ num_cpus = _get_cpu_cores()
+ num_task_slots = num_cpus // num_cpus_per_node
+
+ if num_gpus_per_node > 0:
+ num_gpus = _get_num_physical_gpus()
+ if num_task_slots > num_gpus // num_gpus_per_node:
+ num_task_slots = num_gpus // num_gpus_per_node
+
+ physical_mem_bytes = _get_total_physical_memory()
+ shared_mem_bytes = _get_total_shared_memory()
+
+ (
+ ray_worker_node_heap_mem_bytes,
+ ray_worker_node_object_store_bytes,
+ warning_msg,
+ ) = _calc_mem_per_ray_worker_node(
+ num_task_slots,
+ physical_mem_bytes,
+ shared_mem_bytes,
+ object_store_memory_per_node,
+ )
+ return (
+ ray_worker_node_heap_mem_bytes,
+ ray_worker_node_object_store_bytes,
+ None,
+ warning_msg,
+ )
+
+
+def get_avail_mem_per_ray_worker_node(
+ spark,
+ object_store_memory_per_node,
+ num_cpus_per_node,
+ num_gpus_per_node,
+):
"""
- Return the available heap memory and object store memory for each ray worker.
+ Return the available heap memory and object store memory for each ray worker,
+ and error / warning message if it has.
+ Return value is a tuple of
+ (ray_worker_node_heap_mem_bytes, ray_worker_node_object_store_bytes,
+ error_message, warning_message)
NB: We have one ray node per spark task.
"""
- num_cpus_per_spark_task = int(
- spark.sparkContext.getConf().get("spark.task.cpus", "1")
- )
def mapper(_):
try:
- num_cpus = _get_cpu_cores()
- num_task_slots = num_cpus // num_cpus_per_spark_task
-
- physical_mem_bytes = _get_total_physical_memory()
- shared_mem_bytes = _get_total_shared_memory()
-
- (
- ray_worker_node_heap_mem_bytes,
- ray_worker_node_object_store_bytes,
- ) = _calc_mem_per_ray_worker_node(
- num_task_slots,
- physical_mem_bytes,
- shared_mem_bytes,
+ return _get_avail_mem_per_ray_worker_node(
+ num_cpus_per_node,
+ num_gpus_per_node,
object_store_memory_per_node,
)
- return (
- ray_worker_node_heap_mem_bytes,
- ray_worker_node_object_store_bytes,
- None,
- )
except Exception as e:
- return -1, -1, repr(e)
+ return -1, -1, repr(e), None
# Running memory inference routine on spark executor side since the spark worker
# nodes may have a different machine configuration compared to the spark driver
@@ -223,14 +339,22 @@ def mapper(_):
inferred_ray_worker_node_heap_mem_bytes,
inferred_ray_worker_node_object_store_bytes,
err,
+ warning_msg,
) = (
spark.sparkContext.parallelize([1], 1).map(mapper).collect()[0]
)
if err is not None:
raise RuntimeError(
- f"Inferring ray worker available memory failed, error: {err}"
+ f"Inferring ray worker node available memory failed, error: {err}. "
+ "You can bypass this error by setting following spark configs: "
+ "spark.executorEnv.RAY_ON_SPARK_WORKER_CPU_CORES, "
+ "spark.executorEnv.RAY_ON_SPARK_WORKER_GPU_NUM, "
+ "spark.executorEnv.RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES, "
+ "spark.executorEnv.RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES."
)
+ if warning_msg is not None:
+ _logger.warning(warning_msg)
return (
inferred_ray_worker_node_heap_mem_bytes,
inferred_ray_worker_node_object_store_bytes,
diff --git a/python/requirements_test.txt b/python/requirements_test.txt
index bf1de0d80f46b..256c4d7130aa9 100644
--- a/python/requirements_test.txt
+++ b/python/requirements_test.txt
@@ -53,7 +53,8 @@ PyOpenSSL==22.1.0
pygame==2.1.2; python_version < '3.11'
Pygments==2.13.0
pymongo==4.3.2
-pyspark==3.3.1
+# TODO: Replace this with pyspark==3.4 once it is released.
+https://ml-team-public-read.s3.us-west-2.amazonaws.com/pyspark-3.4.0.dev0.tar.gz
pytest==7.0.1
pytest-asyncio==0.16.0
pytest-rerunfailures==10.2
diff --git a/release/air_tests/air_benchmarks/workloads/xgboost_benchmark.py b/release/air_tests/air_benchmarks/workloads/xgboost_benchmark.py
index 19d41cfa844d5..98d903b88a9d1 100644
--- a/release/air_tests/air_benchmarks/workloads/xgboost_benchmark.py
+++ b/release/air_tests/air_benchmarks/workloads/xgboost_benchmark.py
@@ -112,8 +112,13 @@ def run_xgboost_prediction(model_path: str, data_path: str):
ds = data.read_parquet(data_path)
ckpt = XGBoostCheckpoint.from_model(booster=model)
batch_predictor = BatchPredictor.from_checkpoint(ckpt, XGBoostPredictor)
+ # TODO(https://github.com/ray-project/ray/issues/31723): Once autoscaling
+ # is supported in new execution backend's actor pool, we should remove the
+ # min_scoring_workers and max_scoring_workers.
result = batch_predictor.predict(
ds.drop_columns(["labels"]),
+ min_scoring_workers=10,
+ max_scoring_workers=10,
# Improve prediction throughput for xgboost with larger
# batch size than default 4096
batch_size=8192,
diff --git a/release/nightly_tests/many_nodes_tests/actor_test.py b/release/nightly_tests/many_nodes_tests/actor_test.py
index 95c7475ae4812..620db9cefbd8b 100644
--- a/release/nightly_tests/many_nodes_tests/actor_test.py
+++ b/release/nightly_tests/many_nodes_tests/actor_test.py
@@ -34,11 +34,11 @@ def main():
args, unknown = parse_script_args()
ray.init(address="auto")
-
actor_launch_start = perf_counter()
actors = test_max_actors_launch(args.cpus_per_actor, args.total_actors)
actor_launch_end = perf_counter()
actor_launch_time = actor_launch_end - actor_launch_start
+
if args.fail:
sleep(10)
return
diff --git a/release/nightly_tests/many_nodes_tests/compute_config.yaml b/release/nightly_tests/many_nodes_tests/compute_config.yaml
index 7b9489c46b4ce..8a0e63ebb518d 100644
--- a/release/nightly_tests/many_nodes_tests/compute_config.yaml
+++ b/release/nightly_tests/many_nodes_tests/compute_config.yaml
@@ -1,4 +1,4 @@
-cloud_id: cld_4F7k8814aZzGG8TNUGPKnc
+cloud_id: cld_kvedZWag2qA8i5BjxUevf5i7
region: us-west-2
diff --git a/release/nightly_tests/many_nodes_tests/multi_master_test.py b/release/nightly_tests/many_nodes_tests/multi_master_test.py
new file mode 100644
index 0000000000000..b3f706424eb5a
--- /dev/null
+++ b/release/nightly_tests/many_nodes_tests/multi_master_test.py
@@ -0,0 +1,91 @@
+import argparse
+import os
+from time import sleep, perf_counter
+import json
+import ray
+
+
+def test_max_actors_launch(cpus_per_actor, total_actors, num_masters):
+ # By default, there are 50 groups, each group has 1 master and 99 slaves.
+ num_slaves_per_master = total_actors / num_masters - 1
+
+ @ray.remote(num_cpus=cpus_per_actor)
+ class Actor:
+ def foo(self):
+ pass
+
+ def create(self):
+ return [
+ Actor.options(max_restarts=-1).remote()
+ for _ in range(num_slaves_per_master)
+ ]
+
+ print("Start launch actors")
+ # The 50 masters are spreaded.
+ actors = [
+ Actor.options(max_restarts=-1, scheduling_strategy="SPREAD").remote()
+ for _ in range(num_masters)
+ ]
+ slaves_per_master = []
+ for master in actors:
+ slaves_per_master.append(master.create.remote())
+ for slaves in slaves_per_master:
+ actors.extend(ray.get(slaves))
+ return actors
+
+
+def test_actor_ready(actors):
+ remaining = [actor.foo.remote() for actor in actors]
+ ray.get(remaining)
+
+
+def parse_script_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--cpus-per-actor", type=float, default=0.2)
+ parser.add_argument("--total-actors", type=int, default=5000)
+ parser.add_argument("--num-masters", type=int, default=50)
+ parser.add_argument("--no-report", default=False, action="store_true")
+ parser.add_argument("--fail", default=False, action="store_true")
+ return parser.parse_known_args()
+
+
+def main():
+ args, unknown = parse_script_args()
+
+ ray.init(address="auto")
+
+ actor_launch_start = perf_counter()
+ actors = test_max_actors_launch(
+ args.cpus_per_actor, args.total_actors, args.num_masters
+ )
+ actor_launch_end = perf_counter()
+ actor_launch_time = actor_launch_end - actor_launch_start
+ if args.fail:
+ sleep(10)
+ return
+ actor_ready_start = perf_counter()
+ test_actor_ready(actors)
+ actor_ready_end = perf_counter()
+ actor_ready_time = actor_ready_end - actor_ready_start
+
+ print(f"Actor launch time: {actor_launch_time} ({args.total_actors} actors)")
+ print(f"Actor ready time: {actor_ready_time} ({args.total_actors} actors)")
+ print(
+ f"Total time: {actor_launch_time + actor_ready_time}"
+ f" ({args.total_actors} actors)"
+ )
+
+ if "TEST_OUTPUT_JSON" in os.environ and not args.no_report:
+ out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
+ results = {
+ "actor_launch_time": actor_launch_time,
+ "actor_ready_time": actor_ready_time,
+ "total_time": actor_launch_time + actor_ready_time,
+ "num_actors": args.total_actors,
+ "success": "1",
+ }
+ json.dump(results, out_file)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/release/nightly_tests/stress_tests/test_state_api_scale.py b/release/nightly_tests/stress_tests/test_state_api_scale.py
index b47e8309ad5dc..9972fae1f6d18 100644
--- a/release/nightly_tests/stress_tests/test_state_api_scale.py
+++ b/release/nightly_tests/stress_tests/test_state_api_scale.py
@@ -1,7 +1,7 @@
import click
import json
import ray
-from ray._private.ray_constants import LOG_PREFIX_ACTOR_NAME
+from ray._private.ray_constants import LOG_PREFIX_ACTOR_NAME, LOG_PREFIX_JOB_ID
from ray._private.state_api_test_utils import (
STATE_LIST_LIMIT,
StateAPIMetric,
@@ -62,7 +62,7 @@ def test_many_tasks(num_tasks: int):
invoke_state_api_n(
lambda res: len(res) == 0,
list_tasks,
- filters=[("name", "=", "pi4_sample"), ("scheduling_state", "=", "RUNNING")],
+ filters=[("name", "=", "pi4_sample"), ("state", "=", "RUNNING")],
key_suffix="0",
limit=STATE_LIST_LIMIT,
err_msg="Expect 0 running tasks.",
@@ -97,7 +97,7 @@ def pi4_sample(signal):
invoke_state_api_n(
lambda res: len(res) == num_tasks,
list_tasks,
- filters=[("name", "=", "pi4_sample"), ("scheduling_state", "!=", "FINISHED")],
+ filters=[("name", "=", "pi4_sample"), ("state", "!=", "FINISHED")],
key_suffix=f"{num_tasks}",
limit=STATE_LIST_LIMIT,
err_msg=f"Expect {num_tasks} non finished tasks.",
@@ -112,7 +112,7 @@ def pi4_sample(signal):
invoke_state_api_n(
lambda res: len(res) == 0,
list_tasks,
- filters=[("name", "=", "pi4_sample"), ("scheduling_state", "=", "RUNNING")],
+ filters=[("name", "=", "pi4_sample"), ("state", "=", "RUNNING")],
key_suffix="0",
limit=STATE_LIST_LIMIT,
err_msg="Expect 0 running tasks",
@@ -251,7 +251,8 @@ def test_large_log_file(log_file_size_byte: int):
class LogActor:
def write_log(self, log_file_size_byte: int):
ctx = hashlib.md5()
- prefix = f"{LOG_PREFIX_ACTOR_NAME}LogActor\n"
+ job_id = ray.get_runtime_context().get_job_id()
+ prefix = f"{LOG_PREFIX_JOB_ID}{job_id}\n{LOG_PREFIX_ACTOR_NAME}LogActor\n"
ctx.update(prefix.encode())
while log_file_size_byte > 0:
n = min(log_file_size_byte, 4 * MiB)
diff --git a/release/ray_release/command_runner/command_runner.py b/release/ray_release/command_runner/command_runner.py
index 60f26ec5b7275..ad4c8d6a476cd 100644
--- a/release/ray_release/command_runner/command_runner.py
+++ b/release/ray_release/command_runner/command_runner.py
@@ -3,6 +3,8 @@
from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.file_manager.file_manager import FileManager
+from ray_release.util import exponential_backoff_retry
+from click.exceptions import ClickException
class CommandRunner(abc.ABC):
@@ -85,7 +87,12 @@ def run_prepare_command(
Command runners may choose to run this differently than the
test command.
"""
- return self.run_command(command, env, timeout)
+ return exponential_backoff_retry(
+ lambda: self.run_command(command, env, timeout),
+ ClickException,
+ initial_retry_delay_s=5,
+ max_retries=3,
+ )
def get_last_logs(self):
raise NotImplementedError
diff --git a/release/ray_release/file_manager/job_file_manager.py b/release/ray_release/file_manager/job_file_manager.py
index ddcc1e94dd384..39f34b5dbd178 100644
--- a/release/ray_release/file_manager/job_file_manager.py
+++ b/release/ray_release/file_manager/job_file_manager.py
@@ -26,8 +26,9 @@ def __init__(self, cluster_manager: ClusterManager):
self.s3_client = boto3.client("s3")
self.bucket = str(RELEASE_AWS_BUCKET)
self.job_manager = JobManager(cluster_manager)
-
- sys.path.insert(0, f"{anyscale.ANYSCALE_RAY_DIR}/bin")
+ # Backward compatible
+ if "ANYSCALE_RAY_DIR" in anyscale.__dict__:
+ sys.path.insert(0, f"{anyscale.ANYSCALE_RAY_DIR}/bin")
def _run_with_retry(self, f, initial_retry_delay_s: int = 10):
assert callable(f)
diff --git a/release/ray_release/logger.py b/release/ray_release/logger.py
index 11f7b90a97665..8d79a4368c5b2 100644
--- a/release/ray_release/logger.py
+++ b/release/ray_release/logger.py
@@ -1,7 +1,7 @@
import logging
import sys
-logger = logging.getLogger(__name__)
+logger = logging.getLogger()
logger.setLevel(logging.INFO)
diff --git a/release/release_tests.yaml b/release/release_tests.yaml
index 879cd0132ef07..bb830a6db379b 100644
--- a/release/release_tests.yaml
+++ b/release/release_tests.yaml
@@ -3881,17 +3881,40 @@
frequency: nightly-3x
team: core
+ env: staging
cluster:
cluster_env: many_nodes_tests/app_config.yaml
cluster_compute: many_nodes_tests/compute_config.yaml
run:
timeout: 7200
- script: python many_nodes_tests/actor_test.py
+ # 4cpus per node x 250 nodes / 0.2 cpus per actor = 5k
+ script: python many_nodes_tests/actor_test.py --cpus-per-actor=0.2 --total-actors=5000
wait_for_nodes:
num_nodes: 251
+ type: job
-
+#- name: many_nodes_multi_master_test
+# group: core-daily-test
+# working_dir: nightly_tests
+# legacy:
+# test_name: many_nodes_multi_master_test
+# test_suite: nightly_tests
+#
+# frequency: nightly-3x
+# team: core
+# cluster:
+# cluster_env: many_nodes_tests/app_config.yaml
+# cluster_compute: many_nodes_tests/compute_config.yaml
+#
+# run:
+# timeout: 7200
+# script: python many_nodes_tests/multi_master_test.py
+# wait_for_nodes:
+# num_nodes: 251
+#
+# type: sdk_command
+# file_manager: sdk
- name: pg_autoscaling_regression_test
group: core-daily-test
@@ -4370,7 +4393,6 @@
script: python iter_tensor_batches_benchmark.py
type: sdk_command
- file_manager: sdk
- name: iter_batches_benchmark_single_node
group: data-tests
diff --git a/rllib/__init__.py b/rllib/__init__.py
index c57d3094fb27c..f63b8173d4338 100644
--- a/rllib/__init__.py
+++ b/rllib/__init__.py
@@ -1,3 +1,5 @@
+import logging
+
from ray._private.usage import usage_lib
# Note: do not introduce unnecessary library dependencies here, e.g. gym.
@@ -14,6 +16,18 @@
from ray.tune.registry import register_trainable
+def _setup_logger():
+ logger = logging.getLogger("ray.rllib")
+ handler = logging.StreamHandler()
+ handler.setFormatter(
+ logging.Formatter(
+ "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
+ )
+ )
+ logger.addHandler(handler)
+ logger.propagate = False
+
+
def _register_all():
from ray.rllib.algorithms.registry import ALGORITHMS, _get_algorithm_class
@@ -24,6 +38,8 @@ def _register_all():
register_trainable(key, _get_algorithm_class(key))
+_setup_logger()
+
usage_lib.record_library_usage("rllib")
__all__ = [
diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index 47d50a12db8c4..81859387d0054 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -35,6 +35,12 @@
import ray.cloudpickle as pickle
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.core.rl_module.marl_module import (
+ MultiAgentRLModuleSpec,
+ MultiAgentRLModule,
+)
+
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.algorithms.registry import ALGORITHMS as ALL_ALGORITHMS
from ray.rllib.env.env_context import EnvContext
@@ -493,7 +499,17 @@ def setup(self, config: AlgorithmConfig) -> None:
self._record_usage(self.config)
- self.callbacks = self.config["callbacks"]()
+ # Create the callbacks object.
+ self.callbacks = self.config.callbacks_class()
+
+ if self.config.log_level in ["WARN", "ERROR"]:
+ logger.info(
+ f"Current log_level is {self.config.log_level}. For more information, "
+ "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
+ "-vv flags."
+ )
+ if self.config.log_level:
+ logging.getLogger("ray.rllib").setLevel(self.config.log_level)
# Create local replay buffer if necessary.
self.local_replay_buffer = self._create_local_replay_buffer_if_necessary(
@@ -523,7 +539,7 @@ def setup(self, config: AlgorithmConfig) -> None:
error=True,
help="Running OPE during training is not recommended.",
)
- self.config["off_policy_estimation_methods"] = ope_dict
+ self.config.off_policy_estimation_methods = ope_dict
# Deprecated way of implementing Trainer sub-classes (or "templates"
# via the `build_trainer` utility function).
@@ -557,14 +573,14 @@ def setup(self, config: AlgorithmConfig) -> None:
validate_env=self.validate_env,
default_policy_class=self.get_default_policy_class(self.config),
config=self.config,
- num_workers=self.config["num_workers"],
+ num_workers=self.config.num_rollout_workers,
local_worker=True,
logdir=self.logdir,
)
# TODO (avnishn): Remove the execution plan API by q1 2023
# Function defining one single training iteration's behavior.
- if self.config["_disable_execution_plan_api"]:
+ if self.config._disable_execution_plan_api:
# Ensure remote workers are initially in sync with the local worker.
self.workers.sync_weights()
# LocalIterator-creating "execution plan".
@@ -605,13 +621,13 @@ def setup(self, config: AlgorithmConfig) -> None:
validate_env=None,
default_policy_class=self.get_default_policy_class(self.config),
config=self.evaluation_config,
- num_workers=self.config["evaluation_num_workers"],
+ num_workers=self.config.evaluation_num_workers,
# Don't even create a local worker if num_workers > 0.
local_worker=False,
logdir=self.logdir,
)
- if self.config["enable_async_evaluation"]:
+ if self.config.enable_async_evaluation:
self._evaluation_weights_seq_number = 0
self.evaluation_dataset = None
@@ -641,7 +657,7 @@ def setup(self, config: AlgorithmConfig) -> None:
"dm": DirectMethod,
"dr": DoublyRobust,
}
- for name, method_config in self.config["off_policy_estimation_methods"].items():
+ for name, method_config in self.config.off_policy_estimation_methods.items():
method_type = method_config.pop("type")
if method_type in ope_types:
deprecation_warning(
@@ -662,7 +678,7 @@ def setup(self, config: AlgorithmConfig) -> None:
# offline evaluators.
policy = self.get_policy()
if issubclass(method_type, OffPolicyEstimator):
- method_config["gamma"] = self.config["gamma"]
+ method_config["gamma"] = self.config.gamma
self.reward_estimators[name] = method_type(policy, **method_config)
else:
raise ValueError(
@@ -676,12 +692,27 @@ def setup(self, config: AlgorithmConfig) -> None:
self.trainer_runner = None
if self.config._enable_rl_trainer_api:
- policy = self.get_policy()
- observation_space = policy.observation_space
- action_space = policy.action_space
- trainer_runner_config = self.config.get_trainer_runner_config(
- observation_space, action_space
+ # TODO (Kourosh): This is an interim solution where policies and modules
+ # co-exist. In this world we have both policy_map and MARLModule that need
+ # to be consistent with one another. To make a consistent parity between
+ # the two we need to loop throught the policy modules and create a simple
+ # MARLModule from the RLModule within each policy.
+ local_worker = self.workers.local_worker()
+ module_specs = {}
+
+ for pid, policy in local_worker.policy_map.items():
+ module_specs[pid] = SingleAgentRLModuleSpec(
+ module_class=policy.config["rl_module_class"],
+ observation_space=policy.observation_space,
+ action_space=policy.action_space,
+ model_config=policy.config["model"],
+ )
+
+ module_spec = MultiAgentRLModuleSpec(
+ module_class=MultiAgentRLModule, module_specs=module_specs
)
+
+ trainer_runner_config = self.config.get_trainer_runner_config(module_spec)
self.trainer_runner = trainer_runner_config.build()
# Run `on_algorithm_init` callback after initialization is done.
@@ -739,7 +770,7 @@ def step(self) -> ResultDict:
results: ResultDict = {}
# Parallel eval + training: Kick off evaluation-loop and parallel train() call.
- if evaluate_this_iter and self.config["evaluation_parallel_to_training"]:
+ if evaluate_this_iter and self.config.evaluation_parallel_to_training:
(
results,
train_iter_ctx,
@@ -751,12 +782,12 @@ def step(self) -> ResultDict:
results, train_iter_ctx = self._run_one_training_iteration()
# Sequential: Train (already done above), then evaluate.
- if evaluate_this_iter and not self.config["evaluation_parallel_to_training"]:
+ if evaluate_this_iter and not self.config.evaluation_parallel_to_training:
results.update(self._run_one_evaluation(train_future=None))
# Attach latest available evaluation results to train results,
# if necessary.
- if not evaluate_this_iter and self.config["always_attach_evaluation_results"]:
+ if not evaluate_this_iter and self.config.always_attach_evaluation_results:
assert isinstance(
self.evaluation_metrics, dict
), "Trainer.evaluate() needs to return a dict."
@@ -767,17 +798,15 @@ def step(self) -> ResultDict:
self._sync_filters_if_needed(
from_worker=self.workers.local_worker(),
workers=self.workers,
- timeout_seconds=self.config[
- "sync_filters_on_rollout_workers_timeout_s"
- ],
+ timeout_seconds=self.config.sync_filters_on_rollout_workers_timeout_s,
)
# TODO (avnishn): Remove the execution plan API by q1 2023
# Collect worker metrics and add combine them with `results`.
- if self.config["_disable_execution_plan_api"]:
+ if self.config._disable_execution_plan_api:
episodes_this_iter = collect_episodes(
self.workers,
self._remote_worker_ids_for_metrics(),
- timeout_seconds=self.config["metrics_episode_collection_timeout_s"],
+ timeout_seconds=self.config.metrics_episode_collection_timeout_s,
)
results = self._compile_iteration_results(
episodes_this_iter=episodes_this_iter,
@@ -786,8 +815,8 @@ def step(self) -> ResultDict:
)
# Check `env_task_fn` for possible update of the env's task.
- if self.config["env_task_fn"] is not None:
- if not callable(self.config["env_task_fn"]):
+ if self.config.env_task_fn is not None:
+ if not callable(self.config.env_task_fn):
raise ValueError(
"`env_task_fn` must be None or a callable taking "
"[train_results, env, env_ctx] as args!"
@@ -799,7 +828,7 @@ def fn(env, env_context, task_fn):
if cur_task != new_task:
env.set_task(new_task)
- fn = functools.partial(fn, task_fn=self.config["env_task_fn"])
+ fn = functools.partial(fn, task_fn=self.config.env_task_fn)
self.workers.foreach_env_with_context(fn)
return results
@@ -834,20 +863,20 @@ def evaluate(
self._sync_filters_if_needed(
from_worker=self.workers.local_worker(),
workers=self.evaluation_workers,
- timeout_seconds=self.config[
- "sync_filters_on_rollout_workers_timeout_s"
- ],
+ timeout_seconds=self.config.sync_filters_on_rollout_workers_timeout_s,
)
self.callbacks.on_evaluate_start(algorithm=self)
- if self.config["custom_eval_function"]:
+ if self.config.custom_evaluation_function:
logger.info(
"Running custom eval function {}".format(
- self.config["custom_eval_function"]
+ self.config.custom_evaluation_function
)
)
- metrics = self.config["custom_eval_function"](self, self.evaluation_workers)
+ metrics = self.config.custom_evaluation_function(
+ self, self.evaluation_workers
+ )
if not metrics or not isinstance(metrics, dict):
raise ValueError(
"Custom eval function must return "
@@ -872,15 +901,15 @@ def evaluate(
# How many episodes/timesteps do we need to run?
# In "auto" mode (only for parallel eval + training): Run as long
# as training lasts.
- unit = self.config["evaluation_duration_unit"]
+ unit = self.config.evaluation_duration_unit
eval_cfg = self.evaluation_config
- rollout = eval_cfg["rollout_fragment_length"]
- num_envs = eval_cfg["num_envs_per_worker"]
- auto = self.config["evaluation_duration"] == "auto"
+ rollout = eval_cfg.rollout_fragment_length
+ num_envs = eval_cfg.num_envs_per_worker
+ auto = self.config.evaluation_duration == "auto"
duration = (
- self.config["evaluation_duration"]
+ self.config.evaluation_duration
if not auto
- else (self.config["evaluation_num_workers"] or 1)
+ else (self.config.evaluation_num_workers or 1)
* (1 if unit == "episodes" else rollout)
)
agent_steps_this_iter = 0
@@ -893,7 +922,7 @@ def evaluate(
def duration_fn(num_units_done):
return duration - num_units_done
- logger.info(f"Evaluating current policy for {duration} {unit}.")
+ logger.info(f"Evaluating current state of {self} for {duration} {unit}.")
metrics = None
all_batches = []
@@ -914,8 +943,8 @@ def duration_fn(num_units_done):
all_batches.append(batch)
metrics = collect_metrics(
self.workers,
- keep_custom_metrics=eval_cfg["keep_per_episode_custom_metrics"],
- timeout_seconds=eval_cfg["metrics_episode_collection_timeout_s"],
+ keep_custom_metrics=eval_cfg.keep_per_episode_custom_metrics,
+ timeout_seconds=eval_cfg.metrics_episode_collection_timeout_s,
)
# Evaluation worker set only has local worker.
@@ -960,13 +989,13 @@ def duration_fn(num_units_done):
func=lambda w: w.sample(),
local_worker=False,
remote_worker_ids=selected_eval_worker_ids,
- timeout_seconds=self.config["evaluation_sample_timeout_s"],
+ timeout_seconds=self.config.evaluation_sample_timeout_s,
)
if len(batches) != len(selected_eval_worker_ids):
logger.warning(
"Calling `sample()` on your remote evaluation worker(s) "
"resulted in a timeout (after the configured "
- f"{self.config['evaluation_sample_timeout_s']} seconds)! "
+ f"{self.config.evaluation_sample_timeout_s} seconds)! "
"Try to set `evaluation_sample_timeout_s` in your config"
" to a larger value."
+ (
@@ -1005,7 +1034,7 @@ def duration_fn(num_units_done):
env_steps_this_iter += _env_steps
logger.info(
- f"Ran round {_round} of parallel evaluation "
+ f"Ran round {_round} of non-parallel evaluation "
f"({num_units_done}/{duration if not auto else '?'} "
f"{unit} done)"
)
@@ -1017,8 +1046,8 @@ def duration_fn(num_units_done):
if metrics is None:
metrics = collect_metrics(
self.evaluation_workers,
- keep_custom_metrics=self.config["keep_per_episode_custom_metrics"],
- timeout_seconds=eval_cfg["metrics_episode_collection_timeout_s"],
+ keep_custom_metrics=self.config.keep_per_episode_custom_metrics,
+ timeout_seconds=eval_cfg.metrics_episode_collection_timeout_s,
)
metrics[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps_this_iter
metrics[NUM_ENV_STEPS_SAMPLED_THIS_ITER] = env_steps_this_iter
@@ -1032,9 +1061,7 @@ def duration_fn(num_units_done):
for batch in all_batches:
estimate_result = estimator.estimate(
batch,
- split_batch_by_episode=self.config[
- "ope_split_batch_by_episode"
- ],
+ split_batch_by_episode=self.config.ope_split_batch_by_episode,
)
estimates[name].append(estimate_result)
@@ -1084,15 +1111,15 @@ def _evaluate_async(
# How many episodes/timesteps do we need to run?
# In "auto" mode (only for parallel eval + training): Run as long
# as training lasts.
- unit = self.config["evaluation_duration_unit"]
+ unit = self.config.evaluation_duration_unit
eval_cfg = self.evaluation_config
- rollout = eval_cfg["rollout_fragment_length"]
- num_envs = eval_cfg["num_envs_per_worker"]
- auto = self.config["evaluation_duration"] == "auto"
+ rollout = eval_cfg.rollout_fragment_length
+ num_envs = eval_cfg.num_envs_per_worker
+ auto = self.config.evaluation_duration == "auto"
duration = (
- self.config["evaluation_duration"]
+ self.config.evaluation_duration
if not auto
- else (self.config["evaluation_num_workers"] or 1)
+ else (self.config.evaluation_num_workers or 1)
* (1 if unit == "episodes" else rollout)
)
@@ -1103,17 +1130,17 @@ def _evaluate_async(
self._sync_filters_if_needed(
from_worker=self.workers.local_worker(),
workers=self.evaluation_workers,
- timeout_seconds=eval_cfg.get("sync_filters_on_rollout_workers_timeout_s"),
+ timeout_seconds=eval_cfg.sync_filters_on_rollout_workers_timeout_s,
)
- if self.config["custom_eval_function"]:
+ if self.config.custom_evaluation_function:
raise ValueError(
- "`custom_eval_function` not supported in combination "
+ "`config.custom_evaluation_function` not supported in combination "
"with `enable_async_evaluation=True` config setting!"
)
if self.evaluation_workers is None and (
self.workers.local_worker().input_reader is None
- or self.config["evaluation_num_workers"] == 0
+ or self.config.evaluation_num_workers == 0
):
raise ValueError(
"Evaluation w/o eval workers (calling Algorithm.evaluate() w/o "
@@ -1126,7 +1153,7 @@ def _evaluate_async(
agent_steps_this_iter = 0
env_steps_this_iter = 0
- logger.info(f"Evaluating current policy for {duration} {unit}.")
+ logger.info(f"Evaluating current state of {self} for {duration} {unit}.")
all_batches = []
@@ -1310,11 +1337,11 @@ def training_step(self) -> ResultDict:
# Collect SampleBatches from sample workers until we have a full batch.
if self.config.count_steps_by == "agent_steps":
train_batch = synchronous_parallel_sample(
- worker_set=self.workers, max_agent_steps=self.config["train_batch_size"]
+ worker_set=self.workers, max_agent_steps=self.config.train_batch_size
)
else:
train_batch = synchronous_parallel_sample(
- worker_set=self.workers, max_env_steps=self.config["train_batch_size"]
+ worker_set=self.workers, max_env_steps=self.config.train_batch_size
)
train_batch = train_batch.as_multi_agent()
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
@@ -1417,17 +1444,17 @@ def compute_single_action(
full_fetch: Whether to return extra action fetch results.
This is always set to True if `state` is specified.
explore: Whether to apply exploration to the action.
- Default: None -> use self.config["explore"].
+ Default: None -> use self.config.explore.
timestep: The current (sampling) time step.
episode: This provides access to all of the internal episodes'
state, which may be useful for model-based or multi-agent
algorithms.
unsquash_action: Should actions be unsquashed according to the
env's/Policy's action space? If None, use the value of
- self.config["normalize_actions"].
+ self.config.normalize_actions.
clip_action: Should actions be clipped according to the
env's/Policy's action space? If None, use the value of
- self.config["clip_actions"].
+ self.config.clip_actions.
Keyword Args:
kwargs: forward compatibility placeholder
@@ -1458,10 +1485,10 @@ def compute_single_action(
# `unsquash_action` is None: Use value of config['normalize_actions'].
if unsquash_action is None:
- unsquash_action = self.config["normalize_actions"]
+ unsquash_action = self.config.normalize_actions
# `clip_action` is None: Use value of config['clip_actions'].
elif clip_action is None:
- clip_action = self.config["clip_actions"]
+ clip_action = self.config.clip_actions
# User provided an input-dict: Assert that `obs`, `prev_a|r`, `state`
# are all None.
@@ -1612,17 +1639,17 @@ def compute_actions(
full_fetch: Whether to return extra action fetch results.
This is always set to True if RNN state is specified.
explore: Whether to pick an exploitation or exploration
- action (default: None -> use self.config["explore"]).
+ action (default: None -> use self.config.explore).
timestep: The current (sampling) time step.
episodes: This provides access to all of the internal episodes'
state, which may be useful for model-based or multi-agent
algorithms.
unsquash_actions: Should actions be unsquashed according
to the env's/Policy's action space? If None, use
- self.config["normalize_actions"].
+ self.config.normalize_actions.
clip_actions: Should actions be clipped according to the
env's/Policy's action space? If None, use
- self.config["clip_actions"].
+ self.config.clip_actions.
Keyword Args:
kwargs: forward compatibility placeholder
@@ -1642,10 +1669,10 @@ def compute_actions(
# `unsquash_actions` is None: Use value of config['normalize_actions'].
if unsquash_actions is None:
- unsquash_actions = self.config["normalize_actions"]
+ unsquash_actions = self.config.normalize_actions
# `clip_actions` is None: Use value of config['clip_actions'].
elif clip_actions is None:
- clip_actions = self.config["clip_actions"]
+ clip_actions = self.config.clip_actions
# Preprocess obs and states.
state_defined = state is not None
@@ -2259,7 +2286,7 @@ def _sync_filters_if_needed(
FilterManager.synchronize(
from_worker.filters,
workers,
- update_remote=self.config["synchronize_filters"],
+ update_remote=self.config.synchronize_filters,
timeout_seconds=timeout_seconds,
)
logger.debug("synchronized filters: {}".format(from_worker.filters))
@@ -2803,11 +2830,11 @@ def _should_create_evaluation_rollout_workers(cls, eval_config: "AlgorithmConfig
Returns False if we need to run offline evaluation
(with ope.estimate_on_dastaset API) or when local worker is to be used for
evaluation. Note: We only use estimate_on_dataset API with bandits for now.
- That is when ope_split_batch_by_episode is False. TODO: In future we will do
- the same for episodic RL OPE.
+ That is when ope_split_batch_by_episode is False.
+ TODO: In future we will do the same for episodic RL OPE.
"""
run_offline_evaluation = (
- eval_config.get("off_policy_estimation_methods")
+ eval_config.off_policy_estimation_methods
and not eval_config.ope_split_batch_by_episode
)
return not run_offline_evaluation and (
@@ -2857,7 +2884,7 @@ def _compile_iteration_results(
# Calculate how many (if any) of older, historical episodes we have to add to
# `episodes_this_iter` in order to reach the required smoothing window.
episodes_for_metrics = episodes_this_iter[:]
- missing = self.config["metrics_num_episodes_for_smoothing"] - len(
+ missing = self.config.metrics_num_episodes_for_smoothing - len(
episodes_this_iter
)
# We have to add some older episodes to reach the smoothing window size.
@@ -2865,7 +2892,7 @@ def _compile_iteration_results(
episodes_for_metrics = self._episode_history[-missing:] + episodes_this_iter
assert (
len(episodes_for_metrics)
- <= self.config["metrics_num_episodes_for_smoothing"]
+ <= self.config.metrics_num_episodes_for_smoothing
)
# Note that when there are more than `metrics_num_episodes_for_smoothing`
# episodes in `episodes_for_metrics`, leave them as-is. In this case, we'll
@@ -2875,12 +2902,12 @@ def _compile_iteration_results(
# needed.
self._episode_history.extend(episodes_this_iter)
self._episode_history = self._episode_history[
- -self.config["metrics_num_episodes_for_smoothing"] :
+ -self.config.metrics_num_episodes_for_smoothing :
]
results["sampler_results"] = summarize_episodes(
episodes_for_metrics,
episodes_this_iter,
- self.config["keep_per_episode_custom_metrics"],
+ self.config.keep_per_episode_custom_metrics,
)
# TODO: Don't dump sampler results into top-level.
results.update(results["sampler_results"])
diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index cd720babb0ca6..f2bb985b72efc 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -16,7 +16,11 @@
import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks
-from ray.rllib.core.rl_trainer import TrainerRunnerConfig
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.core.rl_trainer.trainer_runner_config import (
+ TrainerRunnerConfig,
+ ModuleSpec,
+)
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -385,7 +389,7 @@ def __init__(self, algo_class=None):
# `self.debugging()`
self.logger_creator = None
self.logger_config = None
- self.log_level = DEPRECATED_VALUE
+ self.log_level = "WARN"
self.log_sys_usage = True
self.fake_sampler = False
self.seed = None
@@ -2052,7 +2056,7 @@ def debugging(
*,
logger_creator: Optional[Callable[[], Logger]] = NotProvided,
logger_config: Optional[dict] = NotProvided,
- log_level: Optional[str] = DEPRECATED_VALUE,
+ log_level: Optional[str] = NotProvided,
log_sys_usage: Optional[bool] = NotProvided,
fake_sampler: Optional[bool] = NotProvided,
seed: Optional[int] = NotProvided,
@@ -2086,16 +2090,8 @@ def debugging(
self.logger_creator = logger_creator
if logger_config is not NotProvided:
self.logger_config = logger_config
- if log_level != DEPRECATED_VALUE:
- deprecation_warning(
- old="config.log_level",
- help=(
- "RLlib no longer has a separate logging configuration from the rest"
- " of Ray. Configure logging on the root logger; RLlib messages "
- "will be propagated up the logger hierarchy to be handled there."
- ),
- error=False,
- )
+ if log_level is not NotProvided:
+ self.log_level = log_level
if log_sys_usage is not NotProvided:
self.log_sys_usage = log_sys_usage
if fake_sampler is not NotProvided:
@@ -2607,9 +2603,25 @@ def get_default_rl_trainer_class(self) -> Union[Type["RLTrainer"], str]:
raise NotImplementedError
def get_trainer_runner_config(
- self, observation_space: Space, action_space: Space
+ self, module_spec: Optional[ModuleSpec] = None
) -> TrainerRunnerConfig:
+ if module_spec is None:
+ module_spec = SingleAgentRLModuleSpec()
+
+ if isinstance(module_spec, SingleAgentRLModuleSpec):
+ if module_spec.module_class is None:
+ module_spec.module_class = self.rl_module_class
+
+ if module_spec.observation_space is None:
+ module_spec.observation_space = self.observation_space
+
+ if module_spec.action_space is None:
+ module_spec.action_space = self.action_space
+
+ if module_spec.model_config is None:
+ module_spec.model_config = self.model
+
if not self._is_frozen:
raise ValueError(
"Cannot call `get_trainer_runner_config()` on an unfrozen "
@@ -2618,12 +2630,7 @@ def get_trainer_runner_config(
config = (
TrainerRunnerConfig()
- .module(
- module_class=self.rl_module_class,
- observation_space=observation_space,
- action_space=action_space,
- model_config=self.model,
- )
+ .module(module_spec)
.trainer(
trainer_class=self.rl_trainer_class,
eager_tracing=self.eager_tracing,
diff --git a/rllib/core/rl_module/marl_module.py b/rllib/core/rl_module/marl_module.py
index 1a1f4f6d96481..9319795184e12 100644
--- a/rllib/core/rl_module/marl_module.py
+++ b/rllib/core/rl_module/marl_module.py
@@ -1,14 +1,15 @@
import copy
+from dataclasses import dataclass
import pprint
-from typing import Iterator, Mapping, Any, Union, Dict
+from typing import Iterator, Mapping, Any, Union, Dict, Optional, Type
from ray.util.annotations import PublicAPI
-from ray.rllib.utils.annotations import override
+from ray.rllib.utils.annotations import override, ExperimentalAPI
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.models.specs.specs_dict import SpecDict
from ray.rllib.policy.sample_batch import MultiAgentBatch
-from ray.rllib.core.rl_module import RLModule
+from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec
# TODO (Kourosh): change this to module_id later to enforce consistency
from ray.rllib.utils.policy import validate_policy_id
@@ -16,6 +17,23 @@
ModuleID = str
+@ExperimentalAPI
+@dataclass
+class MultiAgentRLModuleSpec:
+ """A utility spec class to make it constructing RLModules (in multi-agent case) easier.
+
+ Args:
+ module_class: ...
+ module_specs: ...
+ """
+
+ module_class: Optional[Type["MultiAgentRLModule"]] = None
+ module_specs: Optional[Dict[ModuleID, SingleAgentRLModuleSpec]] = None
+
+ def build(self) -> "MultiAgentRLModule":
+ return self.module_class.from_multi_agent_config({"modules": self.module_specs})
+
+
def _get_module_configs(config: Dict[str, Any]):
"""Constructs a mapping from module_id to module config.
@@ -26,8 +44,9 @@ def _get_module_configs(config: Dict[str, Any]):
module_specs = config.pop("modules", {})
for common_spec in config:
for module_spec in module_specs.values():
- if common_spec not in module_spec:
- module_spec[common_spec] = config[common_spec]
+ if getattr(module_spec, common_spec) is None:
+ setattr(module_spec, common_spec, config[common_spec])
+
return module_specs
@@ -65,9 +84,8 @@ def from_multi_agent_config(cls, config: Mapping[str, Any]) -> "MultiAgentRLModu
"""Creates a MultiAgentRLModule from a multi-agent config.
The input config should contain "modules" key that is a mapping from module_id
- to the module spec for each RLModule. The module spec should be a dict with the
- following keys: `module_class`, `observation_space`, `action_space`,
- `model_config`. If there are multiple modules that do share the same
+ to the module spec for each RLModule which is a SingleAgentRLModuleSpec object.
+ If there are multiple modules that do share the same
`observation_space`, `action_space`, or `model_config`, you can specify these
keys at the top level of the config, and the module spec will inherit the
values from the top level config.
@@ -78,16 +96,16 @@ def from_multi_agent_config(cls, config: Mapping[str, Any]) -> "MultiAgentRLModu
config = {
"modules": {
- "module_1": {
- "module_class": "RLModule1",
- "observation_space": gym.spaces.Box(...),
- "action_space": gym.spaces.Discrete(...),
- "model_config": {hidden_dim: 256}
- },
- "module_2": {
- "module_class": "RLModule2",
- "observation_space": gym.spaces.Box(...),
- }
+ "module_1": SingleAgentRLModuleSpec(
+ module_class="RLModule1",
+ observation_space=gym.spaces.Box(...),
+ action_space=gym.spaces.Discrete(...),
+ model_config={hidden_dim: 256}
+ )
+ "module_2": SingleAgentRLModuleSpec(
+ module_class="RLModule2",
+ observation_space=gym.spaces.Box(...),
+ )
},
"action_space": gym.spaces.Box(...),
"model_config": {hidden_dim: 32}
@@ -97,17 +115,17 @@ def from_multi_agent_config(cls, config: Mapping[str, Any]) -> "MultiAgentRLModu
config = {
"modules": {
- "module_1": {
- "module_class": "RLModule1",
- "observation_space": gym.spaces.Box(...),
- "action_space": gym.spaces.Discrete(...),
- "model_config": {hidden_dim: 256}
- },
- "module_2": {
- "module_class": "RLModule2",
- "observation_space": gym.spaces.Box(...),
- "action_space": gym.spaces.Box(...), # Inherited
- "model_config": {hidden_dim: 32} # Inherited
+ "module_1": SingleAgentRLModuleSpec(
+ module_class="RLModule1",
+ observation_space=gym.spaces.Box(...),
+ action_space=gym.spaces.Discrete(...),
+ model_config={hidden_dim: 256}
+ )
+ "module_2": SingleAgentRLModuleSpec(
+ module_class="RLModule2",
+ observation_space=gym.spaces.Box(...),
+ action_space=gym.spaces.Box(...), # Inherited
+ model_config={hidden_dim: 32} # Inherited
}
},
}
@@ -126,8 +144,9 @@ def from_multi_agent_config(cls, config: Mapping[str, Any]) -> "MultiAgentRLModu
multiagent_module = cls()
for module_id, module_spec in module_configs.items():
- module_cls: RLModule = module_spec.pop("module_class")
- module = module_cls.from_model_config(**module_spec)
+ # module_cls: RLModule = module_spec.pop("module_class")
+ # module = module_cls.from_model_config(**module_spec)
+ module = module_spec.build()
multiagent_module.add_module(module_id, module)
return multiagent_module
@@ -136,9 +155,8 @@ def from_multi_agent_config(cls, config: Mapping[str, Any]) -> "MultiAgentRLModu
def __check_module_configs(cls, module_configs: Dict[ModuleID, Any]):
"""Checks the module configs for validity.
- The module_configs be a mapping from module_ids to a dict that contains the
- following required keys: `module_class`, `observation_space`, `action_space`,
- `model_config`.
+ The module_configs be a mapping from module_ids to SingleAgentRLModuleSpec
+ objects.
Args:
module_configs: The module configs to check.
@@ -146,19 +164,11 @@ def __check_module_configs(cls, module_configs: Dict[ModuleID, Any]):
Raises:
ValueError: If the module configs are invalid.
"""
- REQUIRED_KEYS = {
- "module_class",
- "observation_space",
- "action_space",
- "model_config",
- }
for module_id, module_spec in module_configs.items():
- for module_key in REQUIRED_KEYS:
- if module_key not in module_spec:
- raise ValueError(
- f"Module config for module_id {module_id} is missing "
- f"required key {module_key}."
- )
+ if not isinstance(module_spec, SingleAgentRLModuleSpec):
+ raise ValueError(
+ f"Module {module_id} is not a SingleAgentRLModuleSpec object."
+ )
def keys(self) -> Iterator[ModuleID]:
"""Returns an iteratable of module ids."""
diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py
index c9c69443eebe6..26dc32ced2a1f 100644
--- a/rllib/core/rl_module/rl_module.py
+++ b/rllib/core/rl_module/rl_module.py
@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
import gymnasium as gym
-from typing import Mapping, Any, TYPE_CHECKING, Union
+from typing import Mapping, Any, TYPE_CHECKING, Union, Optional, Type, Dict
if TYPE_CHECKING:
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule
@@ -24,14 +24,37 @@
ModuleID = str
+@ExperimentalAPI
+@dataclass
+class SingleAgentRLModuleSpec:
+ """A utility spec class to make it constructing RLModules (in single-agent case) easier.
+
+ Args:
+ module_class: ...
+ observation_space: ...
+ action_space: ...
+ model_config: ...
+ """
+
+ module_class: Optional[Type["RLModule"]] = None
+ observation_space: Optional["gym.Space"] = None
+ action_space: Optional["gym.Space"] = None
+ model_config: Optional[Dict[str, Any]] = None
+
+ def build(self) -> "RLModule":
+ return self.module_class.from_model_config(
+ observation_space=self.observation_space,
+ action_space=self.action_space,
+ model_config=self.model_config,
+ )
+
+
@ExperimentalAPI
@dataclass
class RLModuleConfig:
"""Configuration for the PPO module.
-
# TODO (Kourosh): Whether we need this or not really depends on how the catalog
# design end up being.
-
Attributes:
observation_space: The observation space of the environment.
action_space: The action space of the environment.
diff --git a/rllib/core/rl_module/tests/test_marl_module.py b/rllib/core/rl_module/tests/test_marl_module.py
index 3b3c7e5fdb7ea..fa7a2f0305257 100644
--- a/rllib/core/rl_module/tests/test_marl_module.py
+++ b/rllib/core/rl_module/tests/test_marl_module.py
@@ -1,11 +1,13 @@
import unittest
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule, _get_module_configs
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.rllib.utils.test_utils import check
+
DEFAULT_POLICY_ID = "default_policy"
@@ -39,14 +41,14 @@ def test_from_multi_agent_config(self):
multi_agent_dict = {
"modules": {
- "module1": {
- "module_class": DiscreteBCTorchModule,
- "model_config": {"hidden_dim": 64},
- },
- "module2": {
- "module_class": DiscreteBCTorchModule,
- "model_config": {"hidden_dim": 32},
- },
+ "module1": SingleAgentRLModuleSpec(
+ module_class=DiscreteBCTorchModule,
+ model_config={"hidden_dim": 64},
+ ),
+ "module2": SingleAgentRLModuleSpec(
+ module_class=DiscreteBCTorchModule,
+ model_config={"hidden_dim": 32},
+ ),
},
"observation_space": env.observation_space, # this is common
"action_space": env.action_space, # this is common
@@ -160,57 +162,73 @@ def test_get_module_configs(self):
config = {
"modules": {
- "1": {"module_class": "foo", "model_config": "bar"},
- "2": {"module_class": "foo2", "model_config": "bar2"},
+ "1": SingleAgentRLModuleSpec(
+ **{"module_class": "foo", "model_config": "bar"}
+ ),
+ "2": SingleAgentRLModuleSpec(
+ **{"module_class": "foo2", "model_config": "bar2"}
+ ),
},
"observation_space": "obs_space",
"action_space": "action_space",
}
expected_config = {
- "1": {
- "module_class": "foo",
- "model_config": "bar",
- "observation_space": "obs_space",
- "action_space": "action_space",
- },
- "2": {
- "module_class": "foo2",
- "model_config": "bar2",
- "observation_space": "obs_space",
- "action_space": "action_space",
- },
+ "1": SingleAgentRLModuleSpec(
+ **{
+ "module_class": "foo",
+ "model_config": "bar",
+ "observation_space": "obs_space",
+ "action_space": "action_space",
+ }
+ ),
+ "2": SingleAgentRLModuleSpec(
+ **{
+ "module_class": "foo2",
+ "model_config": "bar2",
+ "observation_space": "obs_space",
+ "action_space": "action_space",
+ }
+ ),
}
self.assertDictEqual(_get_module_configs(config), expected_config)
config = {
"modules": {
- "1": {
- "module_class": "foo",
- "model_config": "bar",
- "observation_space": "obs_space1", # won't get overwritten
- "action_space": "action_space1", # won't get overwritten
- },
- "2": {"module_class": "foo2", "model_config": "bar2"},
+ "1": SingleAgentRLModuleSpec(
+ **{
+ "module_class": "foo",
+ "model_config": "bar",
+ "observation_space": "obs_space1", # won't get overwritten
+ "action_space": "action_space1", # won't get overwritten
+ }
+ ),
+ "2": SingleAgentRLModuleSpec(
+ **{"module_class": "foo2", "model_config": "bar2"}
+ ),
},
"observation_space": "obs_space",
"action_space": "action_space",
}
expected_config = {
- "1": {
- "module_class": "foo",
- "model_config": "bar",
- "observation_space": "obs_space1",
- "action_space": "action_space1",
- },
- "2": {
- "module_class": "foo2",
- "model_config": "bar2",
- "observation_space": "obs_space",
- "action_space": "action_space",
- },
+ "1": SingleAgentRLModuleSpec(
+ **{
+ "module_class": "foo",
+ "model_config": "bar",
+ "observation_space": "obs_space1",
+ "action_space": "action_space1",
+ }
+ ),
+ "2": SingleAgentRLModuleSpec(
+ **{
+ "module_class": "foo2",
+ "model_config": "bar2",
+ "observation_space": "obs_space",
+ "action_space": "action_space",
+ }
+ ),
}
self.assertDictEqual(_get_module_configs(config), expected_config)
diff --git a/rllib/core/rl_trainer/rl_trainer.py b/rllib/core/rl_trainer/rl_trainer.py
index b861687905558..8969cab69a1ec 100644
--- a/rllib/core/rl_trainer/rl_trainer.py
+++ b/rllib/core/rl_trainer/rl_trainer.py
@@ -18,8 +18,15 @@
)
from ray.rllib.utils.framework import try_import_tf, try_import_torch
-from ray.rllib.core.rl_module.rl_module import RLModule, ModuleID
-from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule
+from ray.rllib.core.rl_module.rl_module import (
+ RLModule,
+ ModuleID,
+ SingleAgentRLModuleSpec,
+)
+from ray.rllib.core.rl_module.marl_module import (
+ MultiAgentRLModule,
+ MultiAgentRLModuleSpec,
+)
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.numpy import convert_to_numpy
@@ -105,9 +112,12 @@ class RLTrainer:
def __init__(
self,
- module_class: Union[Type[RLModule], Type[MultiAgentRLModule]],
- module_kwargs: Mapping[str, Any],
- optimizer_config: Mapping[str, Any],
+ *,
+ module_spec: Optional[
+ Union[SingleAgentRLModuleSpec, MultiAgentRLModuleSpec]
+ ] = None,
+ module: Optional[RLModule] = None,
+ optimizer_config: Mapping[str, Any] = None,
distributed: bool = False,
scaling_config: Optional["ScalingConfig"] = None,
algorithm_config: Optional["AlgorithmConfig"] = None,
@@ -117,8 +127,18 @@ def __init__(
# understand it. If we can find a better way to make subset of the config
# available to the trainer, that would be great.
# TODO (Kourosh): convert optimizer configs to dataclasses
- self.module_class = module_class
- self.module_kwargs = module_kwargs
+ if module_spec is not None and module is not None:
+ raise ValueError(
+ "Only one of module spec or module can be provided to RLTrainer."
+ )
+
+ if module_spec is None and module is None:
+ raise ValueError(
+ "Either module_spec or module should be provided to RLTrainer."
+ )
+
+ self.module_spec = module_spec
+ self.module_obj = module
self.optimizer_config = optimizer_config
self.distributed = distributed
self.scaling_config = scaling_config
@@ -415,12 +435,11 @@ def add_module(
self,
*,
module_id: ModuleID,
- module_cls: Type[RLModule],
- module_kwargs: Mapping[str, Any],
+ module_spec: SingleAgentRLModuleSpec,
set_optimizer_fn: Optional[Callable[[RLModule], ParamOptimizerPairs]] = None,
optimizer_cls: Optional[Type[Optimizer]] = None,
) -> None:
- """Add a module to the trainer.
+ """Add a module to the underlying MultiAgentRLModule and the trainer.
Args:
module_id: The id of the module to add.
@@ -435,7 +454,7 @@ def add_module(
should be provided.
"""
self.__check_if_build_called()
- module = module_cls.from_model_config(**module_kwargs)
+ module = module_spec.build()
# construct a default set_optimizer_fn if not provided
if set_optimizer_fn is None:
@@ -492,19 +511,11 @@ def _make_module(self) -> MultiAgentRLModule:
Returns:
The constructed module.
"""
-
- if issubclass(self.module_class, MultiAgentRLModule):
- module = self.module_class.from_multi_agent_config(**self.module_kwargs)
- elif issubclass(self.module_class, RLModule):
- module = self.module_class.from_model_config(
- **self.module_kwargs
- ).as_multi_agent()
+ if self.module_obj is not None:
+ module = self.module_obj
else:
- raise ValueError(
- f"Module class {self.module_class} is not a subclass of "
- f"RLModule or MultiAgentRLModule."
- )
-
+ module = self.module_spec.build()
+ module = module.as_multi_agent()
return module
def build(self) -> None:
diff --git a/rllib/core/rl_trainer/tests/test_rl_trainer.py b/rllib/core/rl_trainer/tests/test_rl_trainer.py
index db071ad2ed83f..8af9b0ae72456 100644
--- a/rllib/core/rl_trainer/tests/test_rl_trainer.py
+++ b/rllib/core/rl_trainer/tests/test_rl_trainer.py
@@ -5,6 +5,7 @@
import ray
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer
from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule
from ray.rllib.core.testing.tf.bc_rl_trainer import BCTfRLTrainer
@@ -21,12 +22,12 @@ def get_trainer(distributed=False) -> RLTrainer:
# and internally it will serialize and deserialize the module for distributed
# construction.
trainer = BCTfRLTrainer(
- module_class=DiscreteBCTFModule,
- module_kwargs={
- "observation_space": env.observation_space,
- "action_space": env.action_space,
- "model_config": {"hidden_dim": 32},
- },
+ module_spec=SingleAgentRLModuleSpec(
+ module_class=DiscreteBCTFModule,
+ observation_space=env.observation_space,
+ action_space=env.action_space,
+ model_config={"hidden_dim": 32},
+ ),
optimizer_config={"lr": 1e-3},
distributed=distributed,
)
@@ -125,13 +126,12 @@ def set_optimizer_fn(module):
trainer.add_module(
module_id="test",
- module_cls=DiscreteBCTFModule,
- module_kwargs={
- "observation_space": env.observation_space,
- "action_space": env.action_space,
- # the hidden size is different than the default module
- "model_config": {"hidden_dim": 16},
- },
+ module_spec=SingleAgentRLModuleSpec(
+ module_class=DiscreteBCTFModule,
+ observation_space=env.observation_space,
+ action_space=env.action_space,
+ model_config={"hidden_dim": 16},
+ ),
set_optimizer_fn=set_optimizer_fn,
)
diff --git a/rllib/core/rl_trainer/tests/test_trainer_runner_config.py b/rllib/core/rl_trainer/tests/test_trainer_runner_config.py
index 5c8202a9890ad..acdc677313377 100644
--- a/rllib/core/rl_trainer/tests/test_trainer_runner_config.py
+++ b/rllib/core/rl_trainer/tests/test_trainer_runner_config.py
@@ -2,10 +2,13 @@
import unittest
import ray
+
+from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_trainer.trainer_runner_config import TrainerRunnerConfig
from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule
from ray.rllib.core.testing.tf.bc_rl_trainer import BCTfRLTrainer
-from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
+from ray.rllib.core.testing.utils import get_module_spec
class TestAlgorithmConfig(unittest.TestCase):
@@ -24,12 +27,7 @@ def test_trainer_runner_build(self):
config = (
TrainerRunnerConfig()
- .module(
- module_class=DiscreteBCTFModule,
- observation_space=env.observation_space,
- action_space=env.action_space,
- model_config={"hidden_dim": 32},
- )
+ .module(get_module_spec("tf", env))
.trainer(
trainer_class=BCTfRLTrainer,
)
@@ -50,7 +48,9 @@ def test_trainer_runner_build_from_algorithm_config(self):
)
config.freeze()
runner_config = config.get_trainer_runner_config(
- env.observation_space, env.action_space
+ SingleAgentRLModuleSpec(
+ observation_space=env.observation_space, action_space=env.action_space
+ )
)
runner_config.build()
diff --git a/rllib/core/rl_trainer/tests/test_trainer_runner_local.py b/rllib/core/rl_trainer/tests/test_trainer_runner_local.py
index 81b03b3a1ab40..9986cf98dd3d2 100644
--- a/rllib/core/rl_trainer/tests/test_trainer_runner_local.py
+++ b/rllib/core/rl_trainer/tests/test_trainer_runner_local.py
@@ -1,15 +1,16 @@
import gymnasium as gym
import unittest
-from ray.rllib.utils.framework import try_import_tf
import ray
-from ray.rllib.core.rl_trainer.trainer_runner import TrainerRunner
-from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule
-from ray.rllib.core.testing.tf.bc_rl_trainer import BCTfRLTrainer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch
from ray.rllib.utils.test_utils import check, get_cartpole_dataset_reader
-from ray.rllib.core.testing.utils import add_module_to_runner_or_trainer
+from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.core.testing.utils import (
+ add_module_to_runner_or_trainer,
+ get_trainer_runner,
+ get_rl_trainer,
+)
tf1, tf, tfv = try_import_tf()
@@ -31,47 +32,35 @@ def tearDown(cls) -> None:
def test_trainer_runner_no_gpus(self):
env = gym.make("CartPole-v1")
- trainer_class = BCTfRLTrainer
- trainer_cfg = dict(
- module_class=DiscreteBCTFModule,
- module_kwargs={
- "observation_space": env.observation_space,
- "action_space": env.action_space,
- "model_config": {"hidden_dim": 32},
- },
- optimizer_config={"lr": 1e-3},
- )
- runner = TrainerRunner(
- trainer_class, trainer_cfg, compute_config=dict(num_gpus=0)
- )
-
- local_trainer = trainer_class(**trainer_cfg)
- local_trainer.build()
-
- # make the state of the trainer and the local runner identical
- local_trainer.set_state(runner.get_state()[0])
-
- reader = get_cartpole_dataset_reader(batch_size=500)
- batch = reader.next()
- batch = batch.as_multi_agent()
- check(local_trainer.update(batch), runner.update(batch)[0])
-
- new_module_id = "test_module"
-
- add_module_to_runner_or_trainer("tf", env, new_module_id, runner)
- add_module_to_runner_or_trainer("tf", env, new_module_id, local_trainer)
-
- # make the state of the trainer and the local runner identical
- local_trainer.set_state(runner.get_state()[0])
-
- # do another update
- batch = reader.next()
- ma_batch = MultiAgentBatch(
- {new_module_id: batch, DEFAULT_POLICY_ID: batch}, env_steps=batch.count
- )
- check(local_trainer.update(ma_batch), runner.update(ma_batch)[0])
-
- check(local_trainer.get_state(), runner.get_state()[0])
+ for fw in ["tf", "torch"]:
+ runner = get_trainer_runner(fw, env, compute_config=dict(num_gpus=0))
+ local_trainer = get_rl_trainer(fw, env)
+ local_trainer.build()
+
+ # make the state of the trainer and the local runner identical
+ local_trainer.set_state(runner.get_state()[0])
+
+ reader = get_cartpole_dataset_reader(batch_size=500)
+ batch = reader.next()
+ batch = batch.as_multi_agent()
+ check(local_trainer.update(batch), runner.update(batch)[0])
+
+ new_module_id = "test_module"
+
+ add_module_to_runner_or_trainer(fw, env, new_module_id, runner)
+ add_module_to_runner_or_trainer(fw, env, new_module_id, local_trainer)
+
+ # make the state of the trainer and the local runner identical
+ local_trainer.set_state(runner.get_state()[0])
+
+ # do another update
+ batch = reader.next()
+ ma_batch = MultiAgentBatch(
+ {new_module_id: batch, DEFAULT_POLICY_ID: batch}, env_steps=batch.count
+ )
+ check(local_trainer.update(ma_batch), runner.update(ma_batch)[0])
+
+ check(local_trainer.get_state(), runner.get_state()[0])
if __name__ == "__main__":
diff --git a/rllib/core/rl_trainer/tf/tf_rl_trainer.py b/rllib/core/rl_trainer/tf/tf_rl_trainer.py
index 442932bc00fbf..44a8f6d215817 100644
--- a/rllib/core/rl_trainer/tf/tf_rl_trainer.py
+++ b/rllib/core/rl_trainer/tf/tf_rl_trainer.py
@@ -14,14 +14,21 @@
from ray.rllib.core.rl_trainer.rl_trainer import (
RLTrainer,
- MultiAgentRLModule,
ParamOptimizerPairs,
ParamRef,
Optimizer,
ParamType,
ParamDictType,
)
-from ray.rllib.core.rl_module.rl_module import RLModule, ModuleID
+from ray.rllib.core.rl_module.rl_module import (
+ RLModule,
+ ModuleID,
+ SingleAgentRLModuleSpec,
+)
+from ray.rllib.core.rl_module.marl_module import (
+ MultiAgentRLModule,
+ MultiAgentRLModuleSpec,
+)
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
@@ -86,8 +93,11 @@ class TfRLTrainer(RLTrainer):
def __init__(
self,
- module_class: Union[Type[RLModule], Type[MultiAgentRLModule]],
- module_kwargs: Mapping[str, Any],
+ *,
+ module_spec: Optional[
+ Union[SingleAgentRLModuleSpec, MultiAgentRLModuleSpec]
+ ] = None,
+ module: Optional[RLModule] = None,
optimizer_config: Mapping[str, Any],
distributed: bool = False,
enable_tf_function: bool = True,
@@ -95,8 +105,8 @@ def __init__(
algorithm_config: Optional["AlgorithmConfig"] = None,
):
super().__init__(
- module_class=module_class,
- module_kwargs=module_kwargs,
+ module_spec=module_spec,
+ module=module,
optimizer_config=optimizer_config,
distributed=distributed,
scaling_config=scaling_config,
@@ -195,8 +205,7 @@ def add_module(
self,
*,
module_id: ModuleID,
- module_cls: Type[RLModule],
- module_kwargs: Mapping[str, Any],
+ module_spec: SingleAgentRLModuleSpec,
set_optimizer_fn: Optional[Callable[[RLModule], ParamOptimizerPairs]] = None,
optimizer_cls: Optional[Type[Optimizer]] = None,
) -> None:
@@ -204,16 +213,14 @@ def add_module(
with self.strategy.scope():
super().add_module(
module_id=module_id,
- module_cls=module_cls,
- module_kwargs=module_kwargs,
+ module_spec=module_spec,
set_optimizer_fn=set_optimizer_fn,
optimizer_cls=optimizer_cls,
)
else:
super().add_module(
module_id=module_id,
- module_cls=module_cls,
- module_kwargs=module_kwargs,
+ module_spec=module_spec,
set_optimizer_fn=set_optimizer_fn,
optimizer_cls=optimizer_cls,
)
diff --git a/rllib/core/rl_trainer/torch/tests/test_torch_rl_trainer.py b/rllib/core/rl_trainer/torch/tests/test_torch_rl_trainer.py
index 04835b86dbf5e..40806048b6dcf 100644
--- a/rllib/core/rl_trainer/torch/tests/test_torch_rl_trainer.py
+++ b/rllib/core/rl_trainer/torch/tests/test_torch_rl_trainer.py
@@ -5,6 +5,7 @@
import ray
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
from ray.rllib.core.testing.torch.bc_rl_trainer import BCTorchRLTrainer
@@ -26,12 +27,12 @@ def _get_trainer(scaling_config=None, distributed: bool = False) -> RLTrainer:
# and internally it will serialize and deserialize the module for distributed
# construction.
trainer = BCTorchRLTrainer(
- module_class=DiscreteBCTorchModule,
- module_kwargs={
- "observation_space": env.observation_space,
- "action_space": env.action_space,
- "model_config": {"hidden_dim": 32},
- },
+ module_spec=SingleAgentRLModuleSpec(
+ module_class=DiscreteBCTorchModule,
+ observation_space=env.observation_space,
+ action_space=env.action_space,
+ model_config={"hidden_dim": 32},
+ ),
scaling_config=scaling_config,
optimizer_config={"lr": 1e-3},
distributed=distributed,
@@ -129,13 +130,12 @@ def set_optimizer_fn(module):
trainer.add_module(
module_id="test",
- module_cls=DiscreteBCTorchModule,
- module_kwargs={
- "observation_space": env.observation_space,
- "action_space": env.action_space,
- # the hidden size is different than the default module
- "model_config": {"hidden_dim": 16},
- },
+ module_spec=SingleAgentRLModuleSpec(
+ module_class=DiscreteBCTorchModule,
+ observation_space=env.observation_space,
+ action_space=env.action_space,
+ model_config={"hidden_dim": 16},
+ ),
set_optimizer_fn=set_optimizer_fn,
)
diff --git a/rllib/core/rl_trainer/torch/torch_rl_trainer.py b/rllib/core/rl_trainer/torch/torch_rl_trainer.py
index 43bf72c731217..b8313b47a5843 100644
--- a/rllib/core/rl_trainer/torch/torch_rl_trainer.py
+++ b/rllib/core/rl_trainer/torch/torch_rl_trainer.py
@@ -11,10 +11,17 @@
TYPE_CHECKING,
)
-from ray.rllib.core.rl_module.rl_module import RLModule, ModuleID
+from ray.rllib.core.rl_module.rl_module import (
+ RLModule,
+ ModuleID,
+ SingleAgentRLModuleSpec,
+)
+from ray.rllib.core.rl_module.marl_module import (
+ MultiAgentRLModule,
+ MultiAgentRLModuleSpec,
+)
from ray.rllib.core.rl_trainer.rl_trainer import (
RLTrainer,
- MultiAgentRLModule,
ParamOptimizerPairs,
Optimizer,
ParamType,
@@ -45,16 +52,19 @@ class TorchRLTrainer(RLTrainer):
def __init__(
self,
- module_class: Union[Type[RLModule], Type[MultiAgentRLModule]],
- module_kwargs: Mapping[str, Any],
+ *,
+ module_spec: Optional[
+ Union[SingleAgentRLModuleSpec, MultiAgentRLModuleSpec]
+ ] = None,
+ module: Optional[RLModule] = None,
optimizer_config: Mapping[str, Any],
distributed: bool = False,
scaling_config: Optional["ScalingConfig"] = None,
algorithm_config: Optional["AlgorithmConfig"] = None,
):
super().__init__(
- module_class=module_class,
- module_kwargs=module_kwargs,
+ module_spec=module_spec,
+ module=module,
optimizer_config=optimizer_config,
distributed=distributed,
scaling_config=scaling_config,
@@ -183,15 +193,13 @@ def add_module(
self,
*,
module_id: ModuleID,
- module_cls: Type[RLModule],
- module_kwargs: Mapping[str, Any],
+ module_spec: SingleAgentRLModuleSpec,
set_optimizer_fn: Optional[Callable[[RLModule], ParamOptimizerPairs]] = None,
optimizer_cls: Optional[Type[Optimizer]] = None,
) -> None:
super().add_module(
module_id=module_id,
- module_cls=module_cls,
- module_kwargs=module_kwargs,
+ module_spec=module_spec,
set_optimizer_fn=set_optimizer_fn,
optimizer_cls=optimizer_cls,
)
diff --git a/rllib/core/rl_trainer/trainer_runner.py b/rllib/core/rl_trainer/trainer_runner.py
index 59fba884ab7d4..c6fcda3deacfb 100644
--- a/rllib/core/rl_trainer/trainer_runner.py
+++ b/rllib/core/rl_trainer/trainer_runner.py
@@ -3,7 +3,11 @@
import ray
-from ray.rllib.core.rl_module.rl_module import RLModule, ModuleID
+from ray.rllib.core.rl_module.rl_module import (
+ RLModule,
+ ModuleID,
+ SingleAgentRLModuleSpec,
+)
from ray.rllib.core.rl_trainer.rl_trainer import (
RLTrainer,
ParamOptimizerPairs,
@@ -165,8 +169,7 @@ def add_module(
self,
*,
module_id: ModuleID,
- module_cls: Type[RLModule],
- module_kwargs: Mapping[str, Any],
+ module_spec: SingleAgentRLModuleSpec,
set_optimizer_fn: Optional[Callable[[RLModule], ParamOptimizerPairs]] = None,
optimizer_cls: Optional[Type[Optimizer]] = None,
) -> None:
@@ -174,8 +177,7 @@ def add_module(
Args:
module_id: The id of the module to add.
- module_cls: The module class to add.
- module_kwargs: The config for the module.
+ module_spec: #TODO (Kourosh) fill in here.
set_optimizer_fn: A function that takes in the module and returns a list of
(param, optimizer) pairs. Each element in the tuple describes a
parameter group that share the same optimizer object, if None, the
@@ -189,8 +191,7 @@ def add_module(
for worker in self._workers:
ref = worker.add_module.remote(
module_id=module_id,
- module_cls=module_cls,
- module_kwargs=module_kwargs,
+ module_spec=module_spec,
set_optimizer_fn=set_optimizer_fn,
optimizer_cls=optimizer_cls,
)
@@ -199,8 +200,7 @@ def add_module(
else:
self._trainer.add_module(
module_id=module_id,
- module_cls=module_cls,
- module_kwargs=module_kwargs,
+ module_spec=module_spec,
set_optimizer_fn=set_optimizer_fn,
optimizer_cls=optimizer_cls,
)
diff --git a/rllib/core/rl_trainer/trainer_runner_config.py b/rllib/core/rl_trainer/trainer_runner_config.py
index 4b04bead3fce8..d193e7cbb1f86 100644
--- a/rllib/core/rl_trainer/trainer_runner_config.py
+++ b/rllib/core/rl_trainer/trainer_runner_config.py
@@ -1,12 +1,14 @@
from typing import Type, Optional, TYPE_CHECKING, Union, Dict
+from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.utils.from_config import NotProvided
from ray.rllib.core.rl_trainer.trainer_runner import TrainerRunner
if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
- from ray.rllib.core.rl_module import RLModule
from ray.rllib.core.rl_trainer import RLTrainer
- import gymnasium as gym
+
+ModuleSpec = Union[SingleAgentRLModuleSpec, MultiAgentRLModuleSpec]
# TODO (Kourosh): We should make all configs come from a standard base class that
@@ -20,11 +22,7 @@ def __init__(self, cls: Type[TrainerRunner] = None) -> None:
self.trainer_runner_class = cls or TrainerRunner
# `self.module()`
- self.module_obj = None
- self.module_class = None
- self.observation_space = None
- self.action_space = None
- self.model_config = None
+ self.module_spec = None
# `self.trainer()`
self.trainer_class = None
@@ -40,30 +38,12 @@ def __init__(self, cls: Type[TrainerRunner] = None) -> None:
def validate(self) -> None:
- if self.module_class is None and self.module_obj is None:
+ if self.module_spec is None:
raise ValueError(
- "Cannot initialize an RLTrainer without an RLModule. Please provide "
- "the RLModule class with .module(module_class=MyModuleClass) or "
- "an RLModule instance with .module(module=MyModuleInstance)."
+ "Cannot initialize an RLTrainer without the module specs. "
+ "Please provide the specs via .module(module_spec)."
)
- if self.module_class is not None:
- if self.observation_space is None:
- raise ValueError(
- "Must provide observation_space for RLModule when RLModule class "
- "is provided. Use .module(observation_space=MySpace)."
- )
- if self.action_space is None:
- raise ValueError(
- "Must provide action_space for RLModule when RLModule class "
- "is provided. Use .module(action_space=MySpace)."
- )
- if self.model_config is None:
- raise ValueError(
- "Must provide model_config for RLModule when RLModule class "
- "is provided. Use .module(model_config=MyConfig)."
- )
-
if self.trainer_class is None:
raise ValueError(
"Cannot initialize an RLTrainer without an RLTrainer. Please provide "
@@ -86,17 +66,16 @@ def validate(self) -> None:
def build(self) -> TrainerRunner:
self.validate()
+
+ # If the module class is a multi agent class it will override the default
+ # MultiAgentRLModule class. otherwise, it will be a single agent wrapped with
+ # mutliagent
# TODO (Kourosh): What should be scaling_config? it's not clear what
# should be passed in as trainer_config and what will be inferred
return self.trainer_runner_class(
trainer_class=self.trainer_class,
trainer_config={
- "module_class": self.module_class,
- "module_kwargs": {
- "observation_space": self.observation_space,
- "action_space": self.action_space,
- "model_config": self.model_config,
- },
+ "module_spec": self.module_spec,
# TODO (Kourosh): should this be inferred inside the constructor?
"distributed": self.num_gpus > 1,
# TODO (Avnish): add this
@@ -120,31 +99,24 @@ def algorithm(
def module(
self,
- *,
- module_class: Optional[Type["RLModule"]] = NotProvided,
- observation_space: Optional["gym.Space"] = NotProvided,
- action_space: Optional["gym.Space"] = NotProvided,
- model_config: Optional[dict] = NotProvided,
- module: Optional["RLModule"] = NotProvided,
+ module_spec: Optional[ModuleSpec] = NotProvided,
) -> "TrainerRunnerConfig":
- if module is NotProvided and module_class is NotProvided:
- raise ValueError(
- "Must provide either module or module_class. Please provide "
- "the RLModule class with .module(module=MyModule) or "
- ".module(module_class=MyModuleClass)."
- )
+ if module_spec is not NotProvided:
+ self.module_spec = module_spec
+
+ return self
- if module_class is not NotProvided:
- self.module_class = module_class
- if observation_space is not NotProvided:
- self.observation_space = observation_space
- if action_space is not NotProvided:
- self.action_space = action_space
- if model_config is not NotProvided:
- self.model_config = model_config
- if module is not NotProvided:
- self.module_obj = module
+ def resources(
+ self,
+ num_gpus: Optional[Union[float, int]] = NotProvided,
+ fake_gpus: Optional[bool] = NotProvided,
+ ) -> "TrainerRunnerConfig":
+
+ if num_gpus is not NotProvided:
+ self.num_gpus = num_gpus
+ if fake_gpus is not NotProvided:
+ self.fake_gpus = fake_gpus
return self
@@ -164,16 +136,3 @@ def trainer(
self.optimizer_config = optimizer_config
return self
-
- def resources(
- self,
- num_gpus: Optional[Union[float, int]] = NotProvided,
- fake_gpus: Optional[bool] = NotProvided,
- ) -> "TrainerRunnerConfig":
-
- if num_gpus is not NotProvided:
- self.num_gpus = num_gpus
- if fake_gpus is not NotProvided:
- self.fake_gpus = fake_gpus
-
- return self
diff --git a/rllib/core/testing/utils.py b/rllib/core/testing/utils.py
index 95e24758d62f3..bd96492bac501 100644
--- a/rllib/core/testing/utils.py
+++ b/rllib/core/testing/utils.py
@@ -1,9 +1,16 @@
from typing import Type, Union, TYPE_CHECKING
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.core.rl_trainer.trainer_runner import TrainerRunner
+from ray.rllib.core.rl_module.marl_module import (
+ MultiAgentRLModuleSpec,
+ MultiAgentRLModule,
+)
+from ray.rllib.core.rl_module.tests.test_marl_module import DEFAULT_POLICY_ID
+
if TYPE_CHECKING:
import gymnasium as gym
import torch
@@ -44,6 +51,26 @@ def get_module_class(framework: str) -> Type["RLModule"]:
raise ValueError(f"Unsupported framework: {framework}")
+@DeveloperAPI
+def get_module_spec(framework: str, env: "gym.Env", is_multi_agent: bool = False):
+
+ spec = SingleAgentRLModuleSpec(
+ module_class=get_module_class(framework),
+ observation_space=env.observation_space,
+ action_space=env.action_space,
+ model_config={"hidden_dim": 32},
+ )
+
+ if is_multi_agent:
+ # TODO (Kourosh): Make this more multi-agent for example with policy ids "1",
+ # and "2".
+ return MultiAgentRLModuleSpec(
+ module_class=MultiAgentRLModule, module_specs={DEFAULT_POLICY_ID: spec}
+ )
+ else:
+ return spec
+
+
@DeveloperAPI
def get_optimizer_default_class(framework: str) -> Type[Optimizer]:
if framework == "tf":
@@ -58,18 +85,30 @@ def get_optimizer_default_class(framework: str) -> Type[Optimizer]:
raise ValueError(f"Unsupported framework: {framework}")
+@DeveloperAPI
+def get_rl_trainer(
+ framework: str,
+ env: "gym.Env",
+ is_multi_agent: bool = False,
+) -> "RLTrainer":
+
+ _cls = get_trainer_class(framework)
+ spec = get_module_spec(framework=framework, env=env, is_multi_agent=is_multi_agent)
+ return _cls(module_spec=spec, optimizer_config={"lr": 0.1})
+
+
@DeveloperAPI
def get_trainer_runner(
- framework: str, env: "gym.Env", compute_config: dict
+ framework: str,
+ env: "gym.Env",
+ compute_config: dict,
+ is_multi_agent: bool = False,
) -> TrainerRunner:
trainer_class = get_trainer_class(framework)
trainer_cfg = dict(
- module_class=get_module_class(framework),
- module_kwargs={
- "observation_space": env.observation_space,
- "action_space": env.action_space,
- "model_config": {"hidden_dim": 32},
- },
+ module_spec=get_module_spec(
+ framework=framework, env=env, is_multi_agent=is_multi_agent
+ ),
optimizer_config={"lr": 0.1},
)
runner = TrainerRunner(trainer_class, trainer_cfg, compute_config=compute_config)
@@ -86,11 +125,6 @@ def add_module_to_runner_or_trainer(
):
runner_or_trainer.add_module(
module_id=module_id,
- module_cls=get_module_class(framework),
- module_kwargs={
- "observation_space": env.observation_space,
- "action_space": env.action_space,
- "model_config": {"hidden_dim": 32},
- },
+ module_spec=get_module_spec(framework, env, is_multi_agent=False),
optimizer_cls=get_optimizer_default_class(framework),
)
diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py
index 5be93a9c1a205..dd78c4536f051 100644
--- a/rllib/env/multi_agent_env.py
+++ b/rllib/env/multi_agent_env.py
@@ -84,7 +84,8 @@ def reset(
"traffic_light_1": [0, 3, 5, 1],
}
"""
- raise NotImplementedError
+ # Call super's `reset()` method to (maybe) set the given `seed`.
+ super().reset(seed=seed, options=options)
@PublicAPI
def step(
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py
index 16e2782c810ce..6c876c0fdca52 100644
--- a/rllib/evaluation/rollout_worker.py
+++ b/rllib/evaluation/rollout_worker.py
@@ -521,6 +521,9 @@ def gen_rollouts():
):
tf1.enable_eager_execution()
+ if self.config.log_level:
+ logging.getLogger("ray.rllib").setLevel(self.config.log_level)
+
if self.worker_index > 1:
disable_log_once_globally() # only need 1 worker to log
elif self.config.log_level == "DEBUG":
diff --git a/rllib/examples/env/multi_agent.py b/rllib/examples/env/multi_agent.py
index dbbcad5a63ce1..60260910bac5d 100644
--- a/rllib/examples/env/multi_agent.py
+++ b/rllib/examples/env/multi_agent.py
@@ -36,6 +36,10 @@ def __init__(self, num):
self.resetted = False
def reset(self, *, seed=None, options=None):
+ # Call super's `reset()` method to set the np_random with the value of `seed`.
+ # Note: This call to super does NOT return anything.
+ super().reset(seed=seed)
+
self.resetted = True
self.terminateds = set()
self.truncateds = set()
diff --git a/rllib/examples/env/simple_rpg.py b/rllib/examples/env/simple_rpg.py
index 9544d93f6fb38..7de7390bd96dd 100644
--- a/rllib/examples/env/simple_rpg.py
+++ b/rllib/examples/env/simple_rpg.py
@@ -43,7 +43,7 @@ def __init__(self, config):
self.observation_space = Repeated(self.player_space, max_len=MAX_PLAYERS)
def reset(self, *, seed=None, options=None):
- return self.observation_space.sample()
+ return self.observation_space.sample(), {}
def step(self, action):
return self.observation_space.sample(), 1, True, False, {}
diff --git a/rllib/examples/simulators/sumo/connector.py b/rllib/examples/simulators/sumo/connector.py
index 2d30e14392dc8..f74380e29bf55 100644
--- a/rllib/examples/simulators/sumo/connector.py
+++ b/rllib/examples/simulators/sumo/connector.py
@@ -20,6 +20,7 @@
###############################################################################
+logging.basicConfig()
logger = logging.getLogger(__name__)
###############################################################################
diff --git a/rllib/examples/simulators/sumo/utils.py b/rllib/examples/simulators/sumo/utils.py
index a574635c06145..b9f75552351de 100644
--- a/rllib/examples/simulators/sumo/utils.py
+++ b/rllib/examples/simulators/sumo/utils.py
@@ -29,6 +29,7 @@
###############################################################################
+logging.basicConfig()
logger = logging.getLogger(__name__)
###############################################################################
diff --git a/rllib/examples/sumo_env_local.py b/rllib/examples/sumo_env_local.py
index 4f437af98a945..47f888e631fd0 100644
--- a/rllib/examples/sumo_env_local.py
+++ b/rllib/examples/sumo_env_local.py
@@ -23,7 +23,8 @@
from ray.rllib.examples.simulators.sumo import marlenvironment
from ray.rllib.utils.test_utils import check_learning_achieved
-logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.WARN)
+logger = logging.getLogger("ppotrain")
parser = argparse.ArgumentParser()
parser.add_argument(
diff --git a/rllib/examples/tune/framework.py b/rllib/examples/tune/framework.py
index 78a42208f509b..304b549708e93 100644
--- a/rllib/examples/tune/framework.py
+++ b/rllib/examples/tune/framework.py
@@ -10,7 +10,8 @@
from ray.rllib.algorithms.appo import APPOConfig
from ray.tune import CLIReporter
-logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.WARN)
+logger = logging.getLogger("tune_framework")
def run(smoke_test=False):
diff --git a/src/mock/ray/core_worker/task_manager.h b/src/mock/ray/core_worker/task_manager.h
index 31b8dc39bbc31..7795744b37ee8 100644
--- a/src/mock/ray/core_worker/task_manager.h
+++ b/src/mock/ray/core_worker/task_manager.h
@@ -59,7 +59,7 @@ class MockTaskFinisherInterface : public TaskFinisherInterface {
MOCK_METHOD(void, MarkDependenciesResolved, (const TaskID &task_id), (override));
MOCK_METHOD(void,
MarkTaskWaitingForExecution,
- (const TaskID &task_id, const NodeID &node_id),
+ (const TaskID &task_id, const NodeID &node_id, const WorkerID &worker_id),
(override));
};
diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h
index 04373d1caf6d8..fd01f2bb04719 100644
--- a/src/ray/common/ray_config_def.h
+++ b/src/ray/common/ray_config_def.h
@@ -466,6 +466,11 @@ RAY_CONFIG(int64_t, task_events_max_num_task_events_in_buffer, 10000)
/// Setting the value to -1 allows unlimited profile events to be sent.
RAY_CONFIG(int64_t, task_events_max_num_profile_events_for_task, 100)
+/// The delay in ms that GCS should mark any running tasks from a job as failed.
+/// Setting this value too smaller might result in some finished tasks marked as failed by
+/// GCS.
+RAY_CONFIG(uint64_t, gcs_mark_task_failed_on_job_done_delay_ms, /* 15 secs */ 1000 * 15)
+
/// Whether or not we enable metrics collection.
RAY_CONFIG(bool, enable_metrics_collection, true)
diff --git a/src/ray/common/task/task.cc b/src/ray/common/task/task.cc
index 734ad527a16d2..e2ac8571c4e5e 100644
--- a/src/ray/common/task/task.cc
+++ b/src/ray/common/task/task.cc
@@ -26,12 +26,19 @@ RayTask::RayTask(TaskSpecification task_spec) : task_spec_(std::move(task_spec))
ComputeDependencies();
}
+RayTask::RayTask(TaskSpecification task_spec, std::string preferred_node_id)
+ : task_spec_(std::move(task_spec)), preferred_node_id_(std::move(preferred_node_id)) {
+ ComputeDependencies();
+}
+
const TaskSpecification &RayTask::GetTaskSpecification() const { return task_spec_; }
const std::vector &RayTask::GetDependencies() const {
return dependencies_;
}
+const std::string &RayTask::GetPreferredNodeID() const { return preferred_node_id_; }
+
void RayTask::ComputeDependencies() { dependencies_ = task_spec_.GetDependencies(); }
std::string RayTask::DebugString() const {
diff --git a/src/ray/common/task/task.h b/src/ray/common/task/task.h
index ca5fb324b765a..5a4a9e323de53 100644
--- a/src/ray/common/task/task.h
+++ b/src/ray/common/task/task.h
@@ -43,6 +43,8 @@ class RayTask {
/// Construct a `RayTask` object from a `TaskSpecification`.
RayTask(TaskSpecification task_spec);
+ RayTask(TaskSpecification task_spec, std::string preferred_node_id);
+
/// Get the immutable specification for the task.
///
/// \return The immutable specification for the task.
@@ -54,6 +56,12 @@ class RayTask {
/// \return The object dependencies.
const std::vector &GetDependencies() const;
+ /// Get the task's preferred node id for scheduling. If the returned value
+ /// is empty, then it means the task has no preferred node.
+ ///
+ /// \return The preferred node id.
+ const std::string &GetPreferredNodeID() const;
+
std::string DebugString() const;
private:
@@ -66,6 +74,8 @@ class RayTask {
/// A cached copy of the task's object dependencies, including arguments from
/// the TaskSpecification.
std::vector dependencies_;
+
+ std::string preferred_node_id_;
};
} // namespace ray
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc
index 993d4a6280476..61a10cc5f8476 100644
--- a/src/ray/core_worker/core_worker.cc
+++ b/src/ray/core_worker/core_worker.cc
@@ -790,11 +790,16 @@ void CoreWorker::RegisterToGcs() {
}
auto worker_data = std::make_shared();
+ worker_data->mutable_worker_address()->set_raylet_id(rpc_address_.raylet_id());
+ worker_data->mutable_worker_address()->set_ip_address(rpc_address_.ip_address());
+ worker_data->mutable_worker_address()->set_port(rpc_address_.port());
worker_data->mutable_worker_address()->set_worker_id(worker_id.Binary());
worker_data->set_worker_type(options_.worker_type);
worker_data->mutable_worker_info()->insert(worker_info.begin(), worker_info.end());
+
worker_data->set_is_alive(true);
worker_data->set_pid(getpid());
+ worker_data->set_start_time_ms(current_sys_time_ms());
RAY_CHECK_OK(gcs_client_->Workers().AsyncAdd(worker_data, nullptr));
}
diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc
index 90b8db876c52d..a3839592c9f59 100644
--- a/src/ray/core_worker/task_manager.cc
+++ b/src/ray/core_worker/task_manager.cc
@@ -783,7 +783,8 @@ void TaskManager::MarkDependenciesResolved(const TaskID &task_id) {
}
void TaskManager::MarkTaskWaitingForExecution(const TaskID &task_id,
- const NodeID &node_id) {
+ const NodeID &node_id,
+ const WorkerID &worker_id) {
absl::MutexLock lock(&mu_);
auto it = submissible_tasks_.find(task_id);
if (it == submissible_tasks_.end()) {
@@ -796,7 +797,8 @@ void TaskManager::MarkTaskWaitingForExecution(const TaskID &task_id,
it->second.spec,
rpc::TaskStatus::SUBMITTED_TO_WORKER,
/* include_task_info */ false,
- node_id);
+ node_id,
+ worker_id);
}
void TaskManager::MarkTaskRetryOnResubmit(TaskEntry &task_entry) {
@@ -862,6 +864,10 @@ rpc::TaskInfoEntry TaskManager::MakeTaskInfoEntry(
task_info.mutable_required_resources()->insert(resources_map.begin(),
resources_map.end());
task_info.mutable_runtime_env_info()->CopyFrom(task_spec.RuntimeEnvInfo());
+ const auto &pg_id = task_spec.PlacementGroupBundleId().first;
+ if (!pg_id.IsNil()) {
+ task_info.set_placement_group_id(pg_id.Binary());
+ }
return task_info;
}
@@ -921,7 +927,8 @@ void TaskManager::RecordTaskStatusEvent(int32_t attempt_number,
const TaskSpecification &spec,
rpc::TaskStatus status,
bool include_task_info,
- absl::optional node_id) {
+ absl::optional node_id,
+ absl::optional worker_id) {
if (!task_event_buffer_.Enabled()) {
return;
}
@@ -942,6 +949,12 @@ void TaskManager::RecordTaskStatusEvent(int32_t attempt_number,
<< "Node ID should be included when task status changes to SUBMITTED_TO_WORKER.";
state_updates->set_node_id(node_id->Binary());
}
+ if (worker_id.has_value()) {
+ RAY_CHECK(status == rpc::TaskStatus::SUBMITTED_TO_WORKER)
+ << "Worker ID should be included when task status changes to "
+ "SUBMITTED_TO_WORKER.";
+ state_updates->set_worker_id(worker_id->Binary());
+ }
gcs::FillTaskStatusUpdateTime(status, absl::GetCurrentTimeNanos(), state_updates);
task_event_buffer_.AddTaskEvent(std::move(task_event));
}
diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h
index ab4fb039073d9..b50c58dd26b11 100644
--- a/src/ray/core_worker/task_manager.h
+++ b/src/ray/core_worker/task_manager.h
@@ -53,7 +53,8 @@ class TaskFinisherInterface {
bool fail_immediately = false) = 0;
virtual void MarkTaskWaitingForExecution(const TaskID &task_id,
- const NodeID &node_id) = 0;
+ const NodeID &node_id,
+ const WorkerID &worker_id) = 0;
virtual void OnTaskDependenciesInlined(
const std::vector &inlined_dependency_ids,
@@ -288,7 +289,10 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
///
/// \param[in] task_id The task that is will be running.
/// \param[in] node_id The node id that this task wil be running.
- void MarkTaskWaitingForExecution(const TaskID &task_id, const NodeID &node_id) override;
+ /// \param[in] worker_id The worker id that this task wil be running.
+ void MarkTaskWaitingForExecution(const TaskID &task_id,
+ const NodeID &node_id,
+ const WorkerID &worker_id) override;
/// Add debug information about the current task status for the ObjectRefs
/// included in the given stats.
@@ -321,11 +325,14 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
/// \param include_task_info True if TaskInfoEntry will be added to the Task events.
/// \param node_id Node ID of the worker for which the task's submitted. Only applicable
/// for SUBMITTED_TO_WORKER status change.
+ /// \param worker_id Worker ID of the worker for which the task's submitted. Only
+ /// applicable for SUBMITTED_TO_WORKER status change.
void RecordTaskStatusEvent(int32_t attempt_number,
const TaskSpecification &spec,
rpc::TaskStatus status,
bool include_task_info = false,
- absl::optional node_id = absl::nullopt);
+ absl::optional node_id = absl::nullopt,
+ absl::optional worker_id = absl::nullopt);
private:
struct TaskEntry {
diff --git a/src/ray/core_worker/test/dependency_resolver_test.cc b/src/ray/core_worker/test/dependency_resolver_test.cc
index 1299e626485df..d83318a1e82b9 100644
--- a/src/ray/core_worker/test/dependency_resolver_test.cc
+++ b/src/ray/core_worker/test/dependency_resolver_test.cc
@@ -106,7 +106,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
void MarkDependenciesResolved(const TaskID &task_id) override {}
void MarkTaskWaitingForExecution(const TaskID &task_id,
- const NodeID &node_id) override {}
+ const NodeID &node_id,
+ const WorkerID &worker_id) override {}
int num_tasks_complete = 0;
int num_tasks_failed = 0;
diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc
index 79db6c9bc2d7b..bc90f82c1824a 100644
--- a/src/ray/core_worker/test/direct_task_transport_test.cc
+++ b/src/ray/core_worker/test/direct_task_transport_test.cc
@@ -146,7 +146,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
void MarkDependenciesResolved(const TaskID &task_id) override {}
void MarkTaskWaitingForExecution(const TaskID &task_id,
- const NodeID &node_id) override {}
+ const NodeID &node_id,
+ const WorkerID &worker_id) override {}
int num_tasks_complete = 0;
int num_tasks_failed = 0;
diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc
index db36050d2a43a..3bdb156b3eae4 100644
--- a/src/ray/core_worker/test/task_manager_test.cc
+++ b/src/ray/core_worker/test/task_manager_test.cc
@@ -143,7 +143,8 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@@ -222,7 +223,8 @@ TEST_F(TaskManagerTest, TestPlasmaConcurrentFailure) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@@ -536,7 +538,8 @@ TEST_F(TaskManagerTest, TestLineageEvicted) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
auto return_id = spec.ReturnId(0);
rpc::PushTaskReply reply;
@@ -605,7 +608,8 @@ TEST_F(TaskManagerLineageTest, TestLineagePinned) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@@ -648,7 +652,8 @@ TEST_F(TaskManagerLineageTest, TestDirectObjectNoLineage) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@@ -694,7 +699,8 @@ TEST_F(TaskManagerLineageTest, TestLineagePinnedOutOfOrder) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@@ -727,7 +733,8 @@ TEST_F(TaskManagerLineageTest, TestRecursiveLineagePinned) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@@ -773,7 +780,8 @@ TEST_F(TaskManagerLineageTest, TestRecursiveDirectObjectNoLineage) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@@ -826,7 +834,8 @@ TEST_F(TaskManagerLineageTest, TestResubmitTask) {
ASSERT_TRUE(reference_counter_->IsObjectPendingCreation(return_id));
// The task completes.
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@@ -884,7 +893,8 @@ TEST_F(TaskManagerLineageTest, TestResubmittedTaskNondeterministicReturns) {
// The task completes. Both return objects are stored in plasma.
{
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object1 = reply.add_return_objects();
@@ -916,7 +926,8 @@ TEST_F(TaskManagerLineageTest, TestResubmittedTaskNondeterministicReturns) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object1 = reply.add_return_objects();
@@ -948,7 +959,8 @@ TEST_F(TaskManagerLineageTest, TestResubmittedTaskFails) {
// The task completes. One return object is stored in plasma.
{
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object1 = reply.add_return_objects();
@@ -978,7 +990,8 @@ TEST_F(TaskManagerLineageTest, TestResubmittedTaskFails) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
manager_.FailOrRetryPendingTask(spec.TaskId(), rpc::ErrorType::WORKER_DIED);
@@ -1000,7 +1013,8 @@ TEST_F(TaskManagerLineageTest, TestDynamicReturnsTask) {
// The task completes and returns dynamic returns.
{
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@@ -1061,7 +1075,8 @@ TEST_F(TaskManagerLineageTest, TestResubmittedDynamicReturnsTaskFails) {
// The task completes and returns dynamic returns.
{
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@@ -1102,7 +1117,8 @@ TEST_F(TaskManagerLineageTest, TestResubmittedDynamicReturnsTaskFails) {
manager_.MarkDependenciesResolved(spec.TaskId());
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_FALSE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
- manager_.MarkTaskWaitingForExecution(spec.TaskId(), NodeID::FromRandom());
+ manager_.MarkTaskWaitingForExecution(
+ spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
ASSERT_TRUE(manager_.IsTaskWaitingForExecution(spec.TaskId()));
manager_.FailOrRetryPendingTask(spec.TaskId(), rpc::ErrorType::WORKER_DIED);
diff --git a/src/ray/core_worker/transport/direct_actor_task_submitter.cc b/src/ray/core_worker/transport/direct_actor_task_submitter.cc
index ce098aa026024..e480456d67aab 100644
--- a/src/ray/core_worker/transport/direct_actor_task_submitter.cc
+++ b/src/ray/core_worker/transport/direct_actor_task_submitter.cc
@@ -469,7 +469,8 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(ClientQueue &queue,
};
task_finisher_.MarkTaskWaitingForExecution(task_id,
- NodeID::FromBinary(addr.raylet_id()));
+ NodeID::FromBinary(addr.raylet_id()),
+ WorkerID::FromBinary(addr.worker_id()));
queue.rpc_client->PushActorTask(std::move(request), skip_queue, wrapped_callback);
}
diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc
index e3e4536c477b1..eb5982d269b91 100644
--- a/src/ray/core_worker/transport/direct_task_transport.cc
+++ b/src/ray/core_worker/transport/direct_task_transport.cc
@@ -565,7 +565,7 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
request->mutable_task_spec()->CopyFrom(task_spec.GetMessage());
request->mutable_resource_mapping()->CopyFrom(assigned_resources);
request->set_intended_worker_id(addr.worker_id.Binary());
- task_finisher_->MarkTaskWaitingForExecution(task_id, addr.raylet_id);
+ task_finisher_->MarkTaskWaitingForExecution(task_id, addr.raylet_id, addr.worker_id);
client.PushNormalTask(
std::move(request),
[this,
diff --git a/src/ray/gcs/gcs_client/accessor.cc b/src/ray/gcs/gcs_client/accessor.cc
index c6631eb9a5573..c0b936cda284f 100644
--- a/src/ray/gcs/gcs_client/accessor.cc
+++ b/src/ray/gcs/gcs_client/accessor.cc
@@ -596,7 +596,7 @@ void NodeInfoAccessor::HandleNotification(const GcsNodeInfo &node_info) {
} else {
node.set_node_id(node_info.node_id());
node.set_state(rpc::GcsNodeInfo::DEAD);
- node.set_timestamp(node_info.timestamp());
+ node.set_end_time_ms(node_info.end_time_ms());
}
// If the notification is new, call registered callback.
diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.h b/src/ray/gcs/gcs_server/gcs_actor_manager.h
index 2ce990e7988a2..adbc04b90ee2b 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_actor_manager.h
@@ -104,6 +104,13 @@ class GcsActor {
actor_table_data_.mutable_address()->set_worker_id(WorkerID::Nil().Binary());
actor_table_data_.set_ray_namespace(ray_namespace);
+ if (task_spec.scheduling_strategy().scheduling_strategy_case() ==
+ rpc::SchedulingStrategy::SchedulingStrategyCase::
+ kPlacementGroupSchedulingStrategy) {
+ actor_table_data_.set_placement_group_id(task_spec.scheduling_strategy()
+ .placement_group_scheduling_strategy()
+ .placement_group_id());
+ }
// Set required resources.
auto resource_map =
diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
index 6d98d8b02aa29..0d08b337e8d54 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
+++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
@@ -49,7 +49,8 @@ GcsActorScheduler::GcsActorScheduler(
void GcsActorScheduler::Schedule(std::shared_ptr actor) {
RAY_CHECK(actor->GetNodeID().IsNil() && actor->GetWorkerID().IsNil());
- if (RayConfig::instance().gcs_actor_scheduling_enabled()) {
+ if (RayConfig::instance().gcs_actor_scheduling_enabled() &&
+ !actor->GetCreationTaskSpecification().GetRequiredResources().IsEmpty()) {
ScheduleByGcs(actor);
} else {
ScheduleByRaylet(actor);
@@ -93,7 +94,10 @@ void GcsActorScheduler::ScheduleByGcs(std::shared_ptr actor) {
};
// Queue and schedule the actor locally (gcs).
- cluster_task_manager_->QueueAndScheduleTask(actor->GetCreationTaskSpecification(),
+ const auto &owner_node = gcs_node_manager_.GetAliveNode(actor->GetOwnerNodeID());
+ RayTask task(actor->GetCreationTaskSpecification(),
+ owner_node.has_value() ? actor->GetOwnerNodeID().Binary() : std::string());
+ cluster_task_manager_->QueueAndScheduleTask(task,
/*grant_or_reject*/ false,
/*is_selected_based_on_locality*/ false,
/*reply*/ reply.get(),
diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc
index 938de847176af..bed2b0298fd0d 100644
--- a/src/ray/gcs/gcs_server/gcs_job_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc
@@ -82,7 +82,7 @@ void GcsJobManager::MarkJobAsFinished(rpc::JobTableData job_table_data,
} else {
RAY_CHECK_OK(gcs_publisher_->PublishJob(job_id, job_table_data, nullptr));
runtime_env_manager_.RemoveURIReference(job_id.Hex());
- ClearJobInfos(job_id);
+ ClearJobInfos(job_table_data);
RAY_LOG(INFO) << "Finished marking job state, job id = " << job_id;
}
function_manager_.RemoveJobReference(job_id);
@@ -121,10 +121,10 @@ void GcsJobManager::HandleMarkJobFinished(rpc::MarkJobFinishedRequest request,
}
}
-void GcsJobManager::ClearJobInfos(const JobID &job_id) {
+void GcsJobManager::ClearJobInfos(const rpc::JobTableData &job_data) {
// Notify all listeners.
for (auto &listener : job_finished_listeners_) {
- listener(std::make_shared(job_id));
+ listener(job_data);
}
// Clear cache.
// TODO(qwang): This line will cause `test_actor_advanced.py::test_detached_actor`
@@ -137,8 +137,7 @@ void GcsJobManager::ClearJobInfos(const JobID &job_id) {
/// Add listener to monitor the add action of nodes.
///
/// \param listener The handler which process the add of nodes.
-void GcsJobManager::AddJobFinishedListener(
- std::function)> listener) {
+void GcsJobManager::AddJobFinishedListener(JobFinishListenerCallback listener) {
RAY_CHECK(listener);
job_finished_listeners_.emplace_back(std::move(listener));
}
diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.h b/src/ray/gcs/gcs_server/gcs_job_manager.h
index b3ac1c15055ee..a7c0c25ec997c 100644
--- a/src/ray/gcs/gcs_server/gcs_job_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_job_manager.h
@@ -24,6 +24,8 @@
namespace ray {
namespace gcs {
+using JobFinishListenerCallback = rpc::JobInfoHandler::JobFinishListenerCallback;
+
/// This implementation class of `JobInfoHandler`.
class GcsJobManager : public rpc::JobInfoHandler {
public:
@@ -58,8 +60,7 @@ class GcsJobManager : public rpc::JobInfoHandler {
rpc::GetNextJobIDReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
- void AddJobFinishedListener(
- std::function)> listener) override;
+ void AddJobFinishedListener(JobFinishListenerCallback listener) override;
std::shared_ptr GetJobConfig(const JobID &job_id) const;
@@ -68,14 +69,14 @@ class GcsJobManager : public rpc::JobInfoHandler {
std::shared_ptr gcs_publisher_;
/// Listeners which monitors the finish of jobs.
- std::vector)>> job_finished_listeners_;
+ std::vector job_finished_listeners_;
/// A cached mapping from job id to job config.
absl::flat_hash_map> cached_job_configs_;
ray::RuntimeEnvManager &runtime_env_manager_;
GcsFunctionManager &function_manager_;
- void ClearJobInfos(const JobID &job_id);
+ void ClearJobInfos(const rpc::JobTableData &job_data);
void MarkJobAsFinished(rpc::JobTableData job_table_data,
std::function done_callback);
diff --git a/src/ray/gcs/gcs_server/gcs_kv_manager.cc b/src/ray/gcs/gcs_server/gcs_kv_manager.cc
index 75ff7af7d45b2..33a5d9ae96bb3 100644
--- a/src/ray/gcs/gcs_server/gcs_kv_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_kv_manager.cc
@@ -21,182 +21,6 @@
namespace ray {
namespace gcs {
-namespace {
-
-constexpr std::string_view kNamespacePrefix = "@namespace_";
-constexpr std::string_view kNamespaceSep = ":";
-constexpr std::string_view kClusterSeparator = "@";
-
-} // namespace
-std::string RedisInternalKV::MakeKey(const std::string &ns,
- const std::string &key) const {
- if (ns.empty()) {
- return absl::StrCat(external_storage_namespace_, kClusterSeparator, key);
- }
- return absl::StrCat(external_storage_namespace_,
- kClusterSeparator,
- kNamespacePrefix,
- ns,
- kNamespaceSep,
- key);
-}
-
-Status RedisInternalKV::ValidateKey(const std::string &key) const {
- if (absl::StartsWith(key, kNamespacePrefix)) {
- return Status::KeyError(absl::StrCat("Key can't start with ", kNamespacePrefix));
- }
- return Status::OK();
-}
-
-std::string RedisInternalKV::ExtractKey(const std::string &key) const {
- auto view = std::string_view(key);
- RAY_CHECK(absl::StartsWith(view, external_storage_namespace_))
- << "Invalid key: " << view << ". It should start with "
- << external_storage_namespace_;
- view = view.substr(external_storage_namespace_.size() + kClusterSeparator.size());
- if (absl::StartsWith(view, kNamespacePrefix)) {
- std::vector parts =
- absl::StrSplit(key, absl::MaxSplits(kNamespaceSep, 1));
- RAY_CHECK(parts.size() == 2) << "Invalid key: " << key;
-
- return parts[1];
- }
- return std::string(view.begin(), view.end());
-}
-
-RedisInternalKV::RedisInternalKV(const RedisClientOptions &redis_options)
- : redis_options_(redis_options),
- external_storage_namespace_(::RayConfig::instance().external_storage_namespace()),
- work_(io_service_) {
- RAY_CHECK(!absl::StrContains(external_storage_namespace_, kClusterSeparator))
- << "Storage namespace (" << external_storage_namespace_ << ") shouldn't contain "
- << kClusterSeparator << ".";
- io_thread_ = std::make_unique([this] {
- SetThreadName("InternalKV");
- io_service_.run();
- });
- redis_client_ = std::make_unique(redis_options_);
- RAY_CHECK_OK(redis_client_->Connect(io_service_));
-}
-
-void RedisInternalKV::Get(const std::string &ns,
- const std::string &key,
- std::function)> callback) {
- auto true_key = MakeKey(ns, key);
- std::vector cmd = {"HGET", external_storage_namespace_, true_key};
- RAY_CHECK_OK(redis_client_->GetPrimaryContext()->RunArgvAsync(
- cmd, [callback = std::move(callback)](auto redis_reply) {
- if (callback) {
- if (!redis_reply->IsNil()) {
- callback(redis_reply->ReadAsString());
- } else {
- callback(std::nullopt);
- }
- }
- }));
-}
-
-void RedisInternalKV::Put(const std::string &ns,
- const std::string &key,
- const std::string &value,
- bool overwrite,
- std::function callback) {
- auto true_key = MakeKey(ns, key);
- std::vector cmd = {
- overwrite ? "HSET" : "HSETNX", external_storage_namespace_, true_key, value};
- RAY_CHECK_OK(redis_client_->GetPrimaryContext()->RunArgvAsync(
- cmd, [callback = std::move(callback)](auto redis_reply) {
- if (callback) {
- auto added_num = redis_reply->ReadAsInteger();
- callback(added_num != 0);
- }
- }));
-}
-
-void RedisInternalKV::Del(const std::string &ns,
- const std::string &key,
- bool del_by_prefix,
- std::function callback) {
- auto true_key = MakeKey(ns, key);
- if (del_by_prefix) {
- std::vector cmd = {"HKEYS", external_storage_namespace_};
- RAY_CHECK_OK(redis_client_->GetPrimaryContext()->RunArgvAsync(
- cmd,
- [this, true_key = std::move(true_key), callback = std::move(callback)](
- auto redis_reply) {
- const auto &reply = redis_reply->ReadAsStringArray();
- std::vector del_cmd = {"HDEL", external_storage_namespace_};
- size_t del_num = 0;
- for (const auto &r : reply) {
- RAY_CHECK(r.has_value());
- if (absl::StartsWith(*r, true_key)) {
- del_cmd.emplace_back(*r);
- ++del_num;
- }
- }
-
- // If there are no keys with this prefix, we don't need to send
- // another delete.
- if (del_num == 0) {
- if (callback) {
- callback(0);
- }
- } else {
- RAY_CHECK_OK(redis_client_->GetPrimaryContext()->RunArgvAsync(
- del_cmd, [callback = std::move(callback)](auto redis_reply) {
- if (callback) {
- callback(redis_reply->ReadAsInteger());
- }
- }));
- }
- }));
- } else {
- std::vector cmd = {"HDEL", external_storage_namespace_, true_key};
- RAY_CHECK_OK(redis_client_->GetPrimaryContext()->RunArgvAsync(
- cmd, [callback = std::move(callback)](auto redis_reply) {
- if (callback) {
- callback(redis_reply->ReadAsInteger());
- }
- }));
- }
-}
-
-void RedisInternalKV::Exists(const std::string &ns,
- const std::string &key,
- std::function callback) {
- auto true_key = MakeKey(ns, key);
- std::vector cmd = {"HEXISTS", external_storage_namespace_, true_key};
- RAY_CHECK_OK(redis_client_->GetPrimaryContext()->RunArgvAsync(
- cmd, [callback = std::move(callback)](auto redis_reply) {
- if (callback) {
- bool exists = redis_reply->ReadAsInteger() > 0;
- callback(exists);
- }
- }));
-}
-
-void RedisInternalKV::Keys(const std::string &ns,
- const std::string &prefix,
- std::function)> callback) {
- auto true_prefix = MakeKey(ns, prefix);
- std::vector cmd = {"HKEYS", external_storage_namespace_};
- RAY_CHECK_OK(redis_client_->GetPrimaryContext()->RunArgvAsync(
- cmd,
- [this, true_prefix = std::move(true_prefix), callback = std::move(callback)](
- auto redis_reply) {
- if (callback) {
- const auto &reply = redis_reply->ReadAsStringArray();
- std::vector results;
- for (const auto &r : reply) {
- RAY_CHECK(r.has_value());
- if (absl::StartsWith(*r, true_prefix)) {
- results.emplace_back(ExtractKey(*r));
- }
- }
- callback(std::move(results));
- }
- }));
-}
void GcsInternalKVManager::HandleInternalKVGet(
rpc::InternalKVGetRequest request,
@@ -294,6 +118,7 @@ void GcsInternalKVManager::HandleInternalKVKeys(
}
Status GcsInternalKVManager::ValidateKey(const std::string &key) const {
+ constexpr std::string_view kNamespacePrefix = "@namespace_";
if (absl::StartsWith(key, kNamespacePrefix)) {
return Status::KeyError(absl::StrCat("Key can't start with ", kNamespacePrefix));
}
diff --git a/src/ray/gcs/gcs_server/gcs_kv_manager.h b/src/ray/gcs/gcs_server/gcs_kv_manager.h
index 341a855b25d0f..9ccc9c04a604e 100644
--- a/src/ray/gcs/gcs_server/gcs_kv_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_kv_manager.h
@@ -87,54 +87,6 @@ class InternalKVInterface {
virtual ~InternalKVInterface(){};
};
-class RedisInternalKV : public InternalKVInterface {
- public:
- explicit RedisInternalKV(const RedisClientOptions &redis_options);
-
- ~RedisInternalKV() {
- io_service_.stop();
- io_thread_->join();
- redis_client_.reset();
- io_thread_.reset();
- }
-
- void Get(const std::string &ns,
- const std::string &key,
- std::function)> callback) override;
-
- void Put(const std::string &ns,
- const std::string &key,
- const std::string &value,
- bool overwrite,
- std::function callback) override;
-
- void Del(const std::string &ns,
- const std::string &key,
- bool del_by_prefix,
- std::function callback) override;
-
- void Exists(const std::string &ns,
- const std::string &key,
- std::function callback) override;
-
- void Keys(const std::string &ns,
- const std::string &prefix,
- std::function)> callback) override;
-
- private:
- std::string MakeKey(const std::string &ns, const std::string &key) const;
- Status ValidateKey(const std::string &key) const;
- std::string ExtractKey(const std::string &key) const;
-
- RedisClientOptions redis_options_;
- std::string external_storage_namespace_;
- std::unique_ptr redis_client_;
- // The io service used by internal kv.
- instrumented_io_context io_service_;
- std::unique_ptr io_thread_;
- boost::asio::io_service::work work_;
-};
-
/// This implementation class of `InternalKVHandler`.
class GcsInternalKVManager : public rpc::InternalKVHandler {
public:
diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc
index 13fb0459ef5c8..31363a4c07c47 100644
--- a/src/ray/gcs/gcs_server/gcs_node_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc
@@ -94,12 +94,12 @@ void GcsNodeManager::DrainNode(const NodeID &node_id) {
// Do the procedure to drain a node.
node->set_state(rpc::GcsNodeInfo::DEAD);
- node->set_timestamp(current_sys_time_ms());
+ node->set_end_time_ms(current_sys_time_ms());
AddDeadNodeToCache(node);
auto node_info_delta = std::make_shared();
node_info_delta->set_node_id(node->node_id());
node_info_delta->set_state(node->state());
- node_info_delta->set_timestamp(node->timestamp());
+ node_info_delta->set_end_time_ms(node->end_time_ms());
// Set the address.
rpc::Address remote_address;
remote_address.set_raylet_id(node->node_id());
@@ -251,12 +251,12 @@ std::shared_ptr GcsNodeManager::RemoveNode(
void GcsNodeManager::OnNodeFailure(const NodeID &node_id) {
if (auto node = RemoveNode(node_id, /* is_intended = */ false)) {
node->set_state(rpc::GcsNodeInfo::DEAD);
- node->set_timestamp(current_sys_time_ms());
+ node->set_end_time_ms(current_sys_time_ms());
AddDeadNodeToCache(node);
auto node_info_delta = std::make_shared();
node_info_delta->set_node_id(node->node_id());
node_info_delta->set_state(node->state());
- node_info_delta->set_timestamp(node->timestamp());
+ node_info_delta->set_end_time_ms(node->end_time_ms());
auto on_done = [this, node_id, node_info_delta](const Status &status) {
auto on_done = [this, node_id, node_info_delta](const Status &status) {
@@ -288,7 +288,7 @@ void GcsNodeManager::Initialize(const GcsInitData &gcs_init_data) {
raylet_client->NotifyGCSRestart(nullptr);
} else if (node_info.state() == rpc::GcsNodeInfo::DEAD) {
dead_nodes_.emplace(node_id, std::make_shared(node_info));
- sorted_dead_node_list_.emplace_back(node_id, node_info.timestamp());
+ sorted_dead_node_list_.emplace_back(node_id, node_info.end_time_ms());
}
}
sorted_dead_node_list_.sort(
@@ -305,7 +305,7 @@ void GcsNodeManager::AddDeadNodeToCache(std::shared_ptr node)
}
auto node_id = NodeID::FromBinary(node->node_id());
dead_nodes_.emplace(node_id, node);
- sorted_dead_node_list_.emplace_back(node_id, node->timestamp());
+ sorted_dead_node_list_.emplace_back(node_id, node->end_time_ms());
}
std::string GcsNodeManager::DebugString() const {
diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc
index 222dff1b9e88e..706a483c3760b 100644
--- a/src/ray/gcs/gcs_server/gcs_server.cc
+++ b/src/ray/gcs/gcs_server/gcs_server.cc
@@ -668,9 +668,11 @@ void GcsServer::InstallEventListeners() {
});
// Install job event listeners.
- gcs_job_manager_->AddJobFinishedListener([this](std::shared_ptr job_id) {
- gcs_actor_manager_->OnJobFinished(*job_id);
- gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenJobDead(*job_id);
+ gcs_job_manager_->AddJobFinishedListener([this](const rpc::JobTableData &job_data) {
+ const auto job_id = JobID::FromBinary(job_data.job_id());
+ gcs_actor_manager_->OnJobFinished(job_id);
+ gcs_task_manager_->OnJobFinished(job_id, job_data.end_time());
+ gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenJobDead(job_id);
});
// Install scheduling event listeners.
diff --git a/src/ray/gcs/gcs_server/gcs_task_manager.cc b/src/ray/gcs/gcs_server/gcs_task_manager.cc
index c066645e01500..68dc813758cc0 100644
--- a/src/ray/gcs/gcs_server/gcs_task_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_task_manager.cc
@@ -111,6 +111,22 @@ const rpc::TaskEvents &GcsTaskManager::GcsTaskManagerStorage::GetTaskEvent(
return task_events_.at(idx_itr->second);
}
+void GcsTaskManager::GcsTaskManagerStorage::MarkTaskAttemptFailed(
+ const TaskAttempt &task_attempt, int64_t failed_ts) {
+ auto &task_event = GetTaskEvent(task_attempt);
+ if (!task_event.has_state_updates()) {
+ return;
+ }
+ task_event.mutable_state_updates()->set_failed_ts(failed_ts);
+}
+
+bool GcsTaskManager::GcsTaskManagerStorage::IsTaskTerminated(
+ const TaskID &task_id) const {
+ auto failed_ts = GetTaskStatusUpdateTime(task_id, rpc::TaskStatus::FAILED);
+ auto finished_ts = GetTaskStatusUpdateTime(task_id, rpc::TaskStatus::FINISHED);
+ return failed_ts.has_value() || finished_ts.has_value();
+}
+
absl::optional GcsTaskManager::GcsTaskManagerStorage::GetTaskStatusUpdateTime(
const TaskID &task_id, const rpc::TaskStatus &task_status) const {
auto latest_task_attempt = GetLatestTaskAttempt(task_id);
@@ -124,15 +140,29 @@ absl::optional GcsTaskManager::GcsTaskManagerStorage::GetTaskStatusUpda
: absl::nullopt;
}
+void GcsTaskManager::GcsTaskManagerStorage::MarkTasksFailed(const JobID &job_id,
+ int64_t job_finish_time_ns) {
+ auto task_attempts_itr = job_to_task_attempt_index_.find(job_id);
+ if (task_attempts_itr == job_to_task_attempt_index_.end()) {
+ // No tasks in the job.
+ return;
+ }
+
+ // Iterate all task attempts from the job.
+ for (const auto &task_attempt : task_attempts_itr->second) {
+ if (!IsTaskTerminated(task_attempt.first)) {
+ MarkTaskAttemptFailed(task_attempt, job_finish_time_ns);
+ }
+ }
+}
+
void GcsTaskManager::GcsTaskManagerStorage::MarkTaskFailed(const TaskID &task_id,
int64_t failed_ts) {
auto latest_task_attempt = GetLatestTaskAttempt(task_id);
if (!latest_task_attempt.has_value()) {
return;
}
- auto &task_event = GetTaskEvent(*latest_task_attempt);
- task_event.mutable_state_updates()->set_failed_ts(failed_ts);
- task_event.mutable_state_updates()->clear_finished_ts();
+ MarkTaskAttemptFailed(*latest_task_attempt, failed_ts);
}
void GcsTaskManager::GcsTaskManagerStorage::MarkTaskTreeFailedIfNeeded(
@@ -161,11 +191,8 @@ void GcsTaskManager::GcsTaskManagerStorage::MarkTaskTreeFailedIfNeeded(
continue;
}
for (const auto &child_task_id : children_tasks_itr->second) {
- // Mark any non-terminated child as failed with parent's (or ancestor's) failure
- // timestamp.
- if (!(GetTaskStatusUpdateTime(child_task_id, rpc::TaskStatus::FAILED).has_value() ||
- GetTaskStatusUpdateTime(child_task_id, rpc::TaskStatus::FINISHED)
- .has_value())) {
+ // Mark any non-terminated child as failed with parent's failure timestamp.
+ if (!IsTaskTerminated(child_task_id)) {
MarkTaskFailed(child_task_id, task_failed_ts.value());
failed_tasks.push_back(child_task_id);
}
@@ -349,17 +376,14 @@ void GcsTaskManager::HandleAddTaskEventData(rpc::AddTaskEventDataRequest request
rpc::AddTaskEventDataReply *reply,
rpc::SendReplyCallback send_reply_callback) {
absl::MutexLock lock(&mutex_);
- RAY_LOG(DEBUG) << "Adding task state event:" << request.data().ShortDebugString();
// Dispatch to the handler
auto data = std::move(request.data());
- size_t num_to_process = data.events_by_task_size();
// Update counters.
total_num_profile_task_events_dropped_ += data.num_profile_task_events_dropped();
total_num_status_task_events_dropped_ += data.num_status_task_events_dropped();
for (auto events_by_task : *data.mutable_events_by_task()) {
total_num_task_events_reported_++;
- auto task_id = TaskID::FromBinary(events_by_task.task_id());
// TODO(rickyx): add logic to handle too many profile events for a single task
// attempt. https://github.com/ray-project/ray/issues/31279
@@ -378,11 +402,9 @@ void GcsTaskManager::HandleAddTaskEventData(rpc::AddTaskEventDataRequest request
replaced_task_events->profile_events().events_size();
}
}
- RAY_LOG(DEBUG) << "Processed a task event. [task_id=" << task_id.Hex() << "]";
}
// Processed all the task events
- RAY_LOG(DEBUG) << "Processed all " << num_to_process << " task events";
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
}
@@ -418,5 +440,22 @@ void GcsTaskManager::RecordMetrics() {
task_event_storage_->GetTaskEventsBytes());
}
+void GcsTaskManager::OnJobFinished(const JobID &job_id, int64_t job_finish_time_ms) {
+ RAY_LOG(DEBUG) << "Marking all running tasks of job " << job_id.Hex() << " as failed.";
+ timer_.expires_from_now(boost::posix_time::milliseconds(
+ RayConfig::instance().gcs_mark_task_failed_on_job_done_delay_ms()));
+ timer_.async_wait(
+ [this, job_id, job_finish_time_ms](const boost::system::error_code &error) {
+ if (error == boost::asio::error::operation_aborted) {
+ // timer canceled or aborted.
+ return;
+ }
+ absl::MutexLock lock(&mutex_);
+ // If there are any non-terminated tasks from the job, mark them failed since all
+ // workers associated with the job will be killed.
+ task_event_storage_->MarkTasksFailed(job_id, job_finish_time_ms * 1000);
+ });
+}
+
} // namespace gcs
} // namespace ray
diff --git a/src/ray/gcs/gcs_server/gcs_task_manager.h b/src/ray/gcs/gcs_server/gcs_task_manager.h
index 43ca174b90989..afd49d9604ebf 100644
--- a/src/ray/gcs/gcs_server/gcs_task_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_task_manager.h
@@ -48,7 +48,8 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
// Keep io_service_ alive.
boost::asio::io_service::work io_service_work_(io_service_);
io_service_.run();
- })) {}
+ })),
+ timer_(io_service_) {}
/// Handles a AddTaskEventData request.
///
@@ -76,6 +77,13 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
/// This function returns when the io thread is joined.
void Stop() LOCKS_EXCLUDED(mutex_);
+ /// Handler to be called when a job finishes. This marks all non-terminated tasks
+ /// of the job as failed.
+ ///
+ /// \param job_id Job Id
+ /// \param job_finish_time_ms Job finish time in ms.
+ void OnJobFinished(const JobID &job_id, int64_t job_finish_time_ms);
+
/// Returns the io_service.
///
/// \return Reference to its io_service.
@@ -146,6 +154,13 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
std::vector GetTaskEvents(
const absl::flat_hash_set &task_attempts) const;
+ /// Mark tasks from a job as failed.
+ ///
+ /// \param job_id Job ID
+ /// \param job_finish_time_ns job finished time in nanoseconds, which will be the task
+ /// failed time.
+ void MarkTasksFailed(const JobID &job_id, int64_t job_finish_time_ns);
+
private:
/// Mark the task tree containing this task attempt as failure if necessary.
///
@@ -192,6 +207,12 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
absl::optional GetTaskStatusUpdateTime(
const TaskID &task_id, const rpc::TaskStatus &task_status) const;
+ /// Return if task has terminated.
+ ///
+ /// \param task_id Task id
+ /// \return True if the task has finished or failed timestamp sets, false otherwise.
+ bool IsTaskTerminated(const TaskID &task_id) const;
+
/// Mark the task as failure with the failed timestamp.
///
/// This also overwrites the finished state of the task if the task has finished by
@@ -202,6 +223,12 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
/// timestamp.
void MarkTaskFailed(const TaskID &task_id, int64_t failed_ts);
+ /// Mark a task attempt as failed.
+ ///
+ /// \param task_attempt Task attempt.
+ /// \param failed_ts The failure timestamp.
+ void MarkTaskAttemptFailed(const TaskAttempt &task_attempt, int64_t failed_ts);
+
/// Get the latest task attempt for the task.
///
/// If there is no such task or data loss due to task events dropped at the worker,
@@ -279,10 +306,14 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
/// Its own IO thread from the main thread.
std::unique_ptr io_service_thread_;
+ /// Timer for delay functions.
+ boost::asio::deadline_timer timer_;
+
FRIEND_TEST(GcsTaskManagerTest, TestHandleAddTaskEventBasic);
FRIEND_TEST(GcsTaskManagerTest, TestMergeTaskEventsSameTaskAttempt);
FRIEND_TEST(GcsTaskManagerMemoryLimitedTest, TestLimitTaskEvents);
FRIEND_TEST(GcsTaskManagerMemoryLimitedTest, TestIndexNoLeak);
+ FRIEND_TEST(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks);
};
} // namespace gcs
diff --git a/src/ray/gcs/gcs_server/gcs_worker_manager.cc b/src/ray/gcs/gcs_server/gcs_worker_manager.cc
index 364735cc699c2..6e1a5846c2b0d 100644
--- a/src/ray/gcs/gcs_server/gcs_worker_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_worker_manager.cc
@@ -25,87 +25,106 @@ void GcsWorkerManager::HandleReportWorkerFailure(
rpc::SendReplyCallback send_reply_callback) {
const rpc::Address worker_address = request.worker_failure().worker_address();
const auto worker_id = WorkerID::FromBinary(worker_address.worker_id());
- const auto node_id = NodeID::FromBinary(worker_address.raylet_id());
- std::string message =
- absl::StrCat("Reporting worker exit, worker id = ",
- worker_id.Hex(),
- ", node id = ",
- node_id.Hex(),
- ", address = ",
- worker_address.ip_address(),
- ", exit_type = ",
- rpc::WorkerExitType_Name(request.worker_failure().exit_type()),
- ", exit_detail = ",
- request.worker_failure().exit_detail());
- if (request.worker_failure().exit_type() == rpc::WorkerExitType::INTENDED_USER_EXIT ||
- request.worker_failure().exit_type() == rpc::WorkerExitType::INTENDED_SYSTEM_EXIT) {
- RAY_LOG(DEBUG) << message;
- } else {
- RAY_LOG(WARNING) << message
- << ". Unintentional worker failures have been reported. If there "
- "are lots of this logs, that might indicate there are "
- "unexpected failures in the cluster.";
- }
- auto worker_failure_data = std::make_shared();
- worker_failure_data->CopyFrom(request.worker_failure());
- worker_failure_data->set_is_alive(false);
-
- for (auto &listener : worker_dead_listeners_) {
- listener(worker_failure_data);
- }
+ GetWorkerInfo(
+ worker_id,
+ [this,
+ reply,
+ send_reply_callback,
+ worker_id = std::move(worker_id),
+ request = std::move(request),
+ worker_address =
+ std::move(worker_address)](const boost::optional &result) {
+ const auto node_id = NodeID::FromBinary(worker_address.raylet_id());
+ std::string message =
+ absl::StrCat("Reporting worker exit, worker id = ",
+ worker_id.Hex(),
+ ", node id = ",
+ node_id.Hex(),
+ ", address = ",
+ worker_address.ip_address(),
+ ", exit_type = ",
+ rpc::WorkerExitType_Name(request.worker_failure().exit_type()),
+ ", exit_detail = ",
+ request.worker_failure().exit_detail());
+ if (request.worker_failure().exit_type() ==
+ rpc::WorkerExitType::INTENDED_USER_EXIT ||
+ request.worker_failure().exit_type() ==
+ rpc::WorkerExitType::INTENDED_SYSTEM_EXIT) {
+ RAY_LOG(DEBUG) << message;
+ } else {
+ RAY_LOG(WARNING)
+ << message
+ << ". Unintentional worker failures have been reported. If there "
+ "are lots of this logs, that might indicate there are "
+ "unexpected failures in the cluster.";
+ }
+ auto worker_failure_data = std::make_shared();
+ if (result) {
+ worker_failure_data->CopyFrom(*result);
+ }
+ worker_failure_data->MergeFrom(request.worker_failure());
+ worker_failure_data->set_is_alive(false);
- auto on_done = [this,
- worker_address,
- worker_id,
- node_id,
- worker_failure_data,
- reply,
- send_reply_callback](const Status &status) {
- if (!status.ok()) {
- RAY_LOG(ERROR) << "Failed to report worker failure, worker id = " << worker_id
- << ", node id = " << node_id
- << ", address = " << worker_address.ip_address();
- } else {
- stats::UnintentionalWorkerFailures.Record(1);
- // Only publish worker_id and raylet_id in address as they are the only fields used
- // by sub clients.
- rpc::WorkerDeltaData worker_failure;
- worker_failure.set_worker_id(worker_failure_data->worker_address().worker_id());
- worker_failure.set_raylet_id(worker_failure_data->worker_address().raylet_id());
- RAY_CHECK_OK(
- gcs_publisher_->PublishWorkerFailure(worker_id, worker_failure, nullptr));
- }
- GCS_RPC_SEND_REPLY(send_reply_callback, reply, status);
- };
+ for (auto &listener : worker_dead_listeners_) {
+ listener(worker_failure_data);
+ }
- // As soon as the worker starts, it will register with GCS. It ensures that GCS receives
- // the worker registration information first and then the worker failure message, so we
- // delete the get operation. Related issues:
- // https://github.com/ray-project/ray/pull/11599
- Status status =
- gcs_table_storage_->WorkerTable().Put(worker_id, *worker_failure_data, on_done);
- if (!status.ok()) {
- on_done(status);
- }
+ auto on_done = [this,
+ worker_address,
+ worker_id,
+ node_id,
+ worker_failure_data,
+ reply,
+ send_reply_callback](const Status &status) {
+ if (!status.ok()) {
+ RAY_LOG(ERROR) << "Failed to report worker failure, worker id = " << worker_id
+ << ", node id = " << node_id
+ << ", address = " << worker_address.ip_address();
+ } else {
+ stats::UnintentionalWorkerFailures.Record(1);
+ // Only publish worker_id and raylet_id in address as they are the only fields
+ // used by sub clients.
+ rpc::WorkerDeltaData worker_failure;
+ worker_failure.set_worker_id(
+ worker_failure_data->worker_address().worker_id());
+ worker_failure.set_raylet_id(
+ worker_failure_data->worker_address().raylet_id());
+ RAY_CHECK_OK(
+ gcs_publisher_->PublishWorkerFailure(worker_id, worker_failure, nullptr));
+ }
+ GCS_RPC_SEND_REPLY(send_reply_callback, reply, status);
+ };
+
+ // As soon as the worker starts, it will register with GCS. It ensures that GCS
+ // receives the worker registration information first and then the worker failure
+ // message, so we delete the get operation. Related issues:
+ // https://github.com/ray-project/ray/pull/11599
+ Status status = gcs_table_storage_->WorkerTable().Put(
+ worker_id, *worker_failure_data, on_done);
+ if (!status.ok()) {
+ on_done(status);
+ }
- if (request.worker_failure().exit_type() == rpc::WorkerExitType::SYSTEM_ERROR ||
- request.worker_failure().exit_type() == rpc::WorkerExitType::NODE_OUT_OF_MEMORY) {
- usage::TagKey key;
- int count = 0;
- if (request.worker_failure().exit_type() == rpc::WorkerExitType::SYSTEM_ERROR) {
- worker_crash_system_error_count_ += 1;
- key = usage::TagKey::WORKER_CRASH_SYSTEM_ERROR;
- count = worker_crash_system_error_count_;
- } else if (request.worker_failure().exit_type() ==
- rpc::WorkerExitType::NODE_OUT_OF_MEMORY) {
- worker_crash_oom_count_ += 1;
- key = usage::TagKey::WORKER_CRASH_OOM;
- count = worker_crash_oom_count_;
- }
- if (usage_stats_client_) {
- usage_stats_client_->RecordExtraUsageCounter(key, count);
- }
- }
+ if (request.worker_failure().exit_type() == rpc::WorkerExitType::SYSTEM_ERROR ||
+ request.worker_failure().exit_type() ==
+ rpc::WorkerExitType::NODE_OUT_OF_MEMORY) {
+ usage::TagKey key;
+ int count = 0;
+ if (request.worker_failure().exit_type() == rpc::WorkerExitType::SYSTEM_ERROR) {
+ worker_crash_system_error_count_ += 1;
+ key = usage::TagKey::WORKER_CRASH_SYSTEM_ERROR;
+ count = worker_crash_system_error_count_;
+ } else if (request.worker_failure().exit_type() ==
+ rpc::WorkerExitType::NODE_OUT_OF_MEMORY) {
+ worker_crash_oom_count_ += 1;
+ key = usage::TagKey::WORKER_CRASH_OOM;
+ count = worker_crash_oom_count_;
+ }
+ if (usage_stats_client_) {
+ usage_stats_client_->RecordExtraUsageCounter(key, count);
+ }
+ }
+ });
}
void GcsWorkerManager::HandleGetWorkerInfo(rpc::GetWorkerInfoRequest request,
@@ -114,20 +133,16 @@ void GcsWorkerManager::HandleGetWorkerInfo(rpc::GetWorkerInfoRequest request,
WorkerID worker_id = WorkerID::FromBinary(request.worker_id());
RAY_LOG(DEBUG) << "Getting worker info, worker id = " << worker_id;
- auto on_done = [worker_id, reply, send_reply_callback](
- const Status &status,
- const boost::optional &result) {
- if (result) {
- reply->mutable_worker_table_data()->CopyFrom(*result);
- }
- RAY_LOG(DEBUG) << "Finished getting worker info, worker id = " << worker_id;
- GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
- };
-
- Status status = gcs_table_storage_->WorkerTable().Get(worker_id, on_done);
- if (!status.ok()) {
- on_done(status, boost::none);
- }
+ GetWorkerInfo(worker_id,
+ [reply, send_reply_callback, worker_id = std::move(worker_id)](
+ const boost::optional &result) {
+ if (result) {
+ reply->mutable_worker_table_data()->CopyFrom(*result);
+ }
+ RAY_LOG(DEBUG)
+ << "Finished getting worker info, worker id = " << worker_id;
+ GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
+ });
}
void GcsWorkerManager::HandleGetAllWorkerInfo(
@@ -190,5 +205,26 @@ void GcsWorkerManager::AddWorkerDeadListener(
worker_dead_listeners_.emplace_back(std::move(listener));
}
+void GcsWorkerManager::GetWorkerInfo(
+ const WorkerID &worker_id,
+ std::function &)> callback) const {
+ auto on_done = [worker_id, callback = std::move(callback)](
+ const Status &status,
+ const boost::optional &result) {
+ if (!status.ok()) {
+ RAY_LOG(WARNING) << "Failed to get worker info, worker id = " << worker_id
+ << ", status = " << status;
+ callback(boost::none);
+ } else {
+ callback(result);
+ }
+ };
+
+ Status status = gcs_table_storage_->WorkerTable().Get(worker_id, on_done);
+ if (!status.ok()) {
+ on_done(status, boost::none);
+ }
+}
+
} // namespace gcs
} // namespace ray
diff --git a/src/ray/gcs/gcs_server/gcs_worker_manager.h b/src/ray/gcs/gcs_server/gcs_worker_manager.h
index df0fca359c4c3..a2b4500df454b 100644
--- a/src/ray/gcs/gcs_server/gcs_worker_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_worker_manager.h
@@ -55,6 +55,10 @@ class GcsWorkerManager : public rpc::WorkerInfoHandler {
}
private:
+ void GetWorkerInfo(
+ const WorkerID &worker_id,
+ std::function &)> callback) const;
+
std::shared_ptr gcs_table_storage_;
std::shared_ptr gcs_publisher_;
UsageStatsClient *usage_stats_client_;
diff --git a/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc
index 9c4d69120a6d3..00a3725e9d579 100644
--- a/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc
+++ b/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc
@@ -34,8 +34,6 @@ class GcsKVManagerTest : public ::testing::TestWithParam {
ray::gcs::RedisClientOptions redis_client_options(
"127.0.0.1", ray::TEST_REDIS_SERVER_PORTS.front(), "", false);
if (GetParam() == "redis") {
- kv_instance = std::make_unique(redis_client_options);
- } else if (GetParam() == "redis_client") {
auto client = std::make_shared(redis_client_options);
RAY_CHECK_OK(client->Connect(io_service));
kv_instance = std::make_unique(
@@ -107,7 +105,7 @@ TEST_P(GcsKVManagerTest, TestInternalKV) {
INSTANTIATE_TEST_SUITE_P(GcsKVManagerTestFixture,
GcsKVManagerTest,
- ::testing::Values("redis", "redis_client", "memory"));
+ ::testing::Values("redis", "memory"));
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc
index 8411ebbbca859..cabad9872701a 100644
--- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc
+++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc
@@ -63,6 +63,7 @@ class GcsServerTest : public ::testing::Test {
gcs_server_->Stop();
thread_io_service_->join();
gcs_server_.reset();
+ ray::gcs::RedisCallbackManager::instance().Clear();
}
bool AddJob(const rpc::AddJobRequest &request) {
diff --git a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc
index 9809b0a4feaff..d106f25c7c1d0 100644
--- a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc
+++ b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc
@@ -30,7 +30,8 @@ class GcsTaskManagerTest : public ::testing::Test {
RayConfig::instance().initialize(
R"(
{
- "task_events_max_num_task_in_gcs": 1000
+ "task_events_max_num_task_in_gcs": 1000,
+ "gcs_mark_task_failed_on_job_done_delay_ms": 100
}
)");
}
@@ -61,6 +62,21 @@ class GcsTaskManagerTest : public ::testing::Test {
}
}
+ void SyncAddTaskEvent(
+ const std::vector &tasks,
+ const std::vector> &status_timestamps,
+ const TaskID &parent_task_id = TaskID::Nil(),
+ int job_id = 0) {
+ auto events = GenTaskEvents(tasks,
+ /* attempt_number */ 0,
+ /* job_id */ job_id,
+ /* profile event */ absl::nullopt,
+ GenStateUpdate(status_timestamps),
+ GenTaskInfo(JobID::FromInt(job_id), parent_task_id));
+ auto events_data = Mocker::GenTaskEventsData(events);
+ SyncAddTaskEventData(events_data);
+ }
+
rpc::AddTaskEventDataReply SyncAddTaskEventData(const rpc::TaskEventData &events_data) {
rpc::AddTaskEventDataRequest request;
rpc::AddTaskEventDataReply reply;
@@ -449,38 +465,13 @@ TEST_F(GcsTaskManagerTest, TestFailingParentFailChildren) {
auto child2 = task_ids[2];
// Parent task running
- {
- auto events = GenTaskEvents({parent},
- /* attempt_number */ 0,
- /* job_id */ 0,
- /* profile event */ absl::nullopt,
- GenStateUpdate({{rpc::TaskStatus::RUNNING, 1}}));
- auto events_data = Mocker::GenTaskEventsData(events);
- SyncAddTaskEventData(events_data);
- }
+ SyncAddTaskEvent({parent}, {{rpc::TaskStatus::RUNNING, 1}});
// Child tasks running
- {
- auto events = GenTaskEvents({child1, child2},
- /* attempt_number */ 0,
- /* job_id */ 0,
- /* profile event */ absl::nullopt,
- GenStateUpdate({{rpc::TaskStatus::RUNNING, 2}}),
- GenTaskInfo(/* job_id */ JobID::FromInt(0), parent));
- auto events_data = Mocker::GenTaskEventsData(events);
- SyncAddTaskEventData(events_data);
- }
+ SyncAddTaskEvent({child1, child2}, {{rpc::TaskStatus::RUNNING, 2}}, parent);
// Parent task failed
- {
- auto events = GenTaskEvents({parent},
- /* attempt_number */ 0,
- /* job_id */ 0,
- /* profile event */ absl::nullopt,
- GenStateUpdate({{rpc::TaskStatus::FAILED, 3}}));
- auto events_data = Mocker::GenTaskEventsData(events);
- SyncAddTaskEventData(events_data);
- }
+ SyncAddTaskEvent({parent}, {{rpc::TaskStatus::FAILED, 3}});
// Get all children task events should be failed
{
@@ -502,38 +493,13 @@ TEST_F(GcsTaskManagerTest, TestFailedParentShouldFailGrandChildren) {
auto grand_child2 = task_ids[3];
// Parent task running
- {
- auto events = GenTaskEvents({parent},
- /* attempt_number */ 0,
- /* job_id */ 0,
- /* profile event */ absl::nullopt,
- GenStateUpdate({{rpc::TaskStatus::RUNNING, 1}}));
- auto events_data = Mocker::GenTaskEventsData(events);
- SyncAddTaskEventData(events_data);
- }
+ SyncAddTaskEvent({parent}, {{rpc::TaskStatus::RUNNING, 1}});
// Grandchild tasks running
- {
- auto events = GenTaskEvents({grand_child1, grand_child2},
- /* attempt_number */ 0,
- /* job_id */ 0,
- /* profile event */ absl::nullopt,
- GenStateUpdate({{rpc::TaskStatus::RUNNING, 3}}),
- GenTaskInfo(/* job_id */ JobID::FromInt(0), child));
- auto events_data = Mocker::GenTaskEventsData(events);
- SyncAddTaskEventData(events_data);
- }
+ SyncAddTaskEvent({grand_child1, grand_child2}, {{rpc::TaskStatus::RUNNING, 3}}, child);
// Parent task failed
- {
- auto events = GenTaskEvents({parent},
- /* attempt_number */ 0,
- /* job_id */ 0,
- /* profile event */ absl::nullopt,
- GenStateUpdate({{rpc::TaskStatus::FAILED, 4}}));
- auto events_data = Mocker::GenTaskEventsData(events);
- SyncAddTaskEventData(events_data);
- }
+ SyncAddTaskEvent({parent}, {{rpc::TaskStatus::FAILED, 4}});
// Get grand child should still be running since the parent-grand-child relationship is
// not recorded yet.
@@ -546,16 +512,7 @@ TEST_F(GcsTaskManagerTest, TestFailedParentShouldFailGrandChildren) {
}
// Child task reported running.
- {
- auto events = GenTaskEvents({child},
- /* attempt_number */ 0,
- /* job_id */ 0,
- /* profile event */ absl::nullopt,
- GenStateUpdate({{rpc::TaskStatus::RUNNING, 2}}),
- GenTaskInfo(/* job_id */ JobID::FromInt(0), parent));
- auto events_data = Mocker::GenTaskEventsData(events);
- SyncAddTaskEventData(events_data);
- }
+ SyncAddTaskEvent({child}, {{rpc::TaskStatus::RUNNING, 2}}, parent);
// Both child and grand-child should report failure since their ancestor fail.
// i.e. Child task should mark grandchildren failed.
@@ -568,6 +525,76 @@ TEST_F(GcsTaskManagerTest, TestFailedParentShouldFailGrandChildren) {
}
}
+TEST_F(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks) {
+ auto tasks_running_job1 = GenTaskIDs(10);
+ auto tasks_finished_job1 = GenTaskIDs(10);
+ auto tasks_failed_job1 = GenTaskIDs(10);
+
+ auto tasks_running_job2 = GenTaskIDs(5);
+
+ SyncAddTaskEvent(tasks_running_job1, {{rpc::TaskStatus::RUNNING, 1}}, TaskID::Nil(), 1);
+ SyncAddTaskEvent(
+ tasks_finished_job1, {{rpc::TaskStatus::FINISHED, 2}}, TaskID::Nil(), 1);
+ SyncAddTaskEvent(tasks_failed_job1, {{rpc::TaskStatus::FAILED, 3}}, TaskID::Nil(), 1);
+
+ SyncAddTaskEvent(tasks_running_job2, {{rpc::TaskStatus::RUNNING, 4}}, TaskID::Nil(), 2);
+
+ task_manager->OnJobFinished(JobID::FromInt(1), 5); // in ms
+
+ // Wait for longer than the default timer
+ boost::asio::io_service io;
+ boost::asio::deadline_timer timer(
+ io,
+ boost::posix_time::milliseconds(
+ 2 * RayConfig::instance().gcs_mark_task_failed_on_job_done_delay_ms()));
+ timer.wait();
+
+ // Running tasks from job1 failed at 5
+ {
+ absl::flat_hash_set tasks(tasks_running_job1.begin(),
+ tasks_running_job1.end());
+ auto reply = SyncGetTaskEvents(tasks);
+ EXPECT_EQ(reply.events_by_task_size(), 10);
+ for (const auto &task_event : reply.events_by_task()) {
+ EXPECT_EQ(task_event.state_updates().failed_ts(), 5000);
+ }
+ }
+
+ // Finished tasks from job1 remain finished
+ {
+ absl::flat_hash_set tasks(tasks_finished_job1.begin(),
+ tasks_finished_job1.end());
+ auto reply = SyncGetTaskEvents(tasks);
+ EXPECT_EQ(reply.events_by_task_size(), 10);
+ for (const auto &task_event : reply.events_by_task()) {
+ EXPECT_EQ(task_event.state_updates().finished_ts(), 2);
+ EXPECT_FALSE(task_event.state_updates().has_failed_ts());
+ }
+ }
+
+ // Failed tasks from job1 failed timestamp not overriden
+ {
+ absl::flat_hash_set tasks(tasks_failed_job1.begin(), tasks_failed_job1.end());
+ auto reply = SyncGetTaskEvents(tasks);
+ EXPECT_EQ(reply.events_by_task_size(), 10);
+ for (const auto &task_event : reply.events_by_task()) {
+ EXPECT_EQ(task_event.state_updates().failed_ts(), 3);
+ }
+ }
+
+ // Tasks from job2 should not be affected.
+ {
+ absl::flat_hash_set tasks(tasks_running_job2.begin(),
+ tasks_running_job2.end());
+ auto reply = SyncGetTaskEvents(tasks);
+ EXPECT_EQ(reply.events_by_task_size(), 5);
+ for (const auto &task_event : reply.events_by_task()) {
+ EXPECT_FALSE(task_event.state_updates().has_failed_ts());
+ EXPECT_FALSE(task_event.state_updates().has_finished_ts());
+ }
+ }
+}
+
TEST_F(GcsTaskManagerMemoryLimitedTest, TestIndexNoLeak) {
size_t num_limit = 100; // synced with test config
size_t num_total = 1000;
diff --git a/src/ray/gcs/pb_util.h b/src/ray/gcs/pb_util.h
index a7fc7624219b0..e799faedf47bf 100644
--- a/src/ray/gcs/pb_util.h
+++ b/src/ray/gcs/pb_util.h
@@ -105,24 +105,20 @@ inline std::shared_ptr CreateActorTableData(
/// Helper function to produce worker failure data.
inline std::shared_ptr CreateWorkerFailureData(
- const NodeID &raylet_id,
const WorkerID &worker_id,
- const std::string &address,
- int32_t port,
int64_t timestamp,
rpc::WorkerExitType disconnect_type,
const std::string &disconnect_detail,
int pid,
const rpc::RayException *creation_task_exception = nullptr) {
auto worker_failure_info_ptr = std::make_shared();
- worker_failure_info_ptr->mutable_worker_address()->set_raylet_id(raylet_id.Binary());
+ // Only report the worker id + delta (new data upon worker failures).
+ // GCS will merge the data with original worker data.
worker_failure_info_ptr->mutable_worker_address()->set_worker_id(worker_id.Binary());
- worker_failure_info_ptr->mutable_worker_address()->set_ip_address(address);
- worker_failure_info_ptr->mutable_worker_address()->set_port(port);
worker_failure_info_ptr->set_timestamp(timestamp);
worker_failure_info_ptr->set_exit_type(disconnect_type);
worker_failure_info_ptr->set_exit_detail(disconnect_detail);
- worker_failure_info_ptr->set_pid(pid);
+ worker_failure_info_ptr->set_end_time_ms(current_sys_time_ms());
if (creation_task_exception != nullptr) {
// this pointer will be freed by protobuf internal codes
auto copied_data = new rpc::RayException(*creation_task_exception);
diff --git a/src/ray/gcs/redis_client.h b/src/ray/gcs/redis_client.h
index 90c6a92267353..ccf7a43b55fbb 100644
--- a/src/ray/gcs/redis_client.h
+++ b/src/ray/gcs/redis_client.h
@@ -20,14 +20,13 @@
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/common/status.h"
#include "ray/gcs/asio.h"
+#include "ray/gcs/redis_context.h"
#include "ray/util/logging.h"
namespace ray {
namespace gcs {
-class RedisContext;
-
class RedisClientOptions {
public:
RedisClientOptions(const std::string &ip,
diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc
index 1f964543806dd..a16e35f90d48b 100644
--- a/src/ray/gcs/redis_context.cc
+++ b/src/ray/gcs/redis_context.cc
@@ -212,6 +212,11 @@ std::shared_ptr RedisCallbackManager::GetCal
return it->second;
}
+void RedisCallbackManager::Clear() {
+ std::lock_guard lock(mutex_);
+ callback_items_.clear();
+}
+
void RedisCallbackManager::RemoveCallback(int64_t callback_index) {
std::lock_guard lock(mutex_);
callback_items_.erase(callback_index);
diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h
index 7c4f350219f51..f3fc9328a1f18 100644
--- a/src/ray/gcs/redis_context.h
+++ b/src/ray/gcs/redis_context.h
@@ -145,6 +145,9 @@ class RedisCallbackManager {
/// Get a callback.
std::shared_ptr GetCallback(int64_t callback_index) const;
+ /// Clear all callbacks.
+ void Clear();
+
private:
RedisCallbackManager() : num_callbacks_(0){};
diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto
index 0c5b10a437826..e471dc4adaff0 100644
--- a/src/ray/protobuf/common.proto
+++ b/src/ray/protobuf/common.proto
@@ -430,6 +430,10 @@ message TaskInfoEntry {
// If the task type is Actor creation task or Actor task
// this is set. Otherwise, it is empty.
optional bytes actor_id = 25;
+ // The placement group id of this task.
+ // If the task/actor is created within a placement group,
+ // this value is configured.
+ optional bytes placement_group_id = 26;
}
message Bundle {
diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto
index 07cfde4f7e567..b008fe9593e63 100644
--- a/src/ray/protobuf/gcs.proto
+++ b/src/ray/protobuf/gcs.proto
@@ -154,6 +154,8 @@ message ActorTableData {
// so we have a separate field to track this.
// If the actor is restarting, the node id could be incorrect.
optional bytes node_id = 29;
+ // Placement group ID if the actor requires a placement group.
+ optional bytes placement_group_id = 30;
}
message ErrorTableData {
@@ -212,6 +214,8 @@ message TaskStateUpdate {
optional int64 finished_ts = 6;
// Timestamp when status changes to FAILED.
optional int64 failed_ts = 7;
+ // Worker that runs the task.
+ optional bytes worker_id = 8;
}
// Represents events and state changes from a single task run.
@@ -285,14 +289,16 @@ message GcsNodeInfo {
// The port at which the node will expose metrics to.
int32 metrics_export_port = 9;
- // Timestamp that the node is dead.
- int64 timestamp = 10;
// The total resources of this node.
map resources_total = 11;
// The user-provided identifier or name for this node.
string node_name = 12;
+ // The unix ms timestamp the node was started at.
+ uint64 start_time_ms = 23;
+ // The unix ms timestamp the node was ended at.
+ uint64 end_time_ms = 24;
}
message JobTableData {
@@ -336,6 +342,10 @@ message WorkerTableData {
optional string exit_detail = 20;
// pid of the worker process.
uint32 pid = 21;
+ // The unix ms timestamp the worker was started at.
+ uint64 start_time_ms = 23;
+ // The unix ms timestamp the worker was ended at.
+ uint64 end_time_ms = 24;
}
// Fields to publish when worker fails.
diff --git a/src/ray/raylet/local_task_manager.cc b/src/ray/raylet/local_task_manager.cc
index 2f2a2f7efcb91..4b4fd1452750a 100644
--- a/src/ray/raylet/local_task_manager.cc
+++ b/src/ray/raylet/local_task_manager.cc
@@ -334,7 +334,7 @@ void LocalTaskManager::SpillWaitingTasks() {
if (!task.GetTaskSpecification().IsSpreadSchedulingStrategy()) {
scheduling_node_id = cluster_resource_scheduler_->GetBestSchedulableNode(
task.GetTaskSpecification(),
- /*prioritize_local_node*/ true,
+ /*preferred_node_id*/ self_node_id_.Binary(),
/*exclude_local_node*/ task_dependencies_blocked,
/*requires_object_store_memory*/ true,
&is_infeasible);
@@ -379,7 +379,7 @@ bool LocalTaskManager::TrySpillback(const std::shared_ptr &work,
// We should prefer to stay local if possible
// to avoid unnecessary spillback
// since this node is already selected by the cluster scheduler.
- /*prioritize_local_node*/ true,
+ /*preferred_node_id*/ self_node_id_.Binary(),
/*exclude_local_node*/ false,
/*requires_object_store_memory*/ false,
&is_infeasible);
@@ -1023,7 +1023,12 @@ ResourceRequest LocalTaskManager::CalcNormalTaskResources() const {
}
if (auto allocated_instances = worker->GetAllocatedInstances()) {
- total_normal_task_resources += allocated_instances->ToResourceRequest();
+ auto resource_request = allocated_instances->ToResourceRequest();
+ // Blocked normal task workers have temporarily released its allocated CPU.
+ if (worker->IsBlocked()) {
+ resource_request.Set(ResourceID::CPU(), 0);
+ }
+ total_normal_task_resources += resource_request;
}
}
return total_normal_task_resources;
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index 5ca8c5b415890..ad019b5dc233f 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -1457,10 +1457,7 @@ void NodeManager::DisconnectClient(const std::shared_ptr &clie
}
// Publish the worker failure.
auto worker_failure_data_ptr =
- gcs::CreateWorkerFailureData(self_node_id_,
- worker->WorkerId(),
- worker->IpAddress(),
- worker->Port(),
+ gcs::CreateWorkerFailureData(worker->WorkerId(),
time(nullptr),
disconnect_type,
disconnect_detail,
diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc
index e7171661bd709..0b552167921b6 100644
--- a/src/ray/raylet/raylet.cc
+++ b/src/ray/raylet/raylet.cc
@@ -91,6 +91,7 @@ Raylet::Raylet(instrumented_io_context &main_service,
auto resource_map = node_manager_config.resource_config.ToResourceMap();
self_node_info_.mutable_resources_total()->insert(resource_map.begin(),
resource_map.end());
+ self_node_info_.set_start_time_ms(current_sys_time_ms());
}
Raylet::~Raylet() {}
diff --git a/src/ray/raylet/scheduling/cluster_resource_manager.h b/src/ray/raylet/scheduling/cluster_resource_manager.h
index c41b9b31a6a91..1db8fcf0d3256 100644
--- a/src/ray/raylet/scheduling/cluster_resource_manager.h
+++ b/src/ray/raylet/scheduling/cluster_resource_manager.h
@@ -168,6 +168,7 @@ class ClusterResourceManager {
FRIEND_TEST(ClusterResourceSchedulerTest, SchedulingAddOrUpdateNodeTest);
FRIEND_TEST(ClusterResourceSchedulerTest, NodeAffinitySchedulingStrategyTest);
FRIEND_TEST(ClusterResourceSchedulerTest, SpreadSchedulingStrategyTest);
+ FRIEND_TEST(ClusterResourceSchedulerTest, SchedulingWithPreferredNodeTest);
FRIEND_TEST(ClusterResourceSchedulerTest, SchedulingResourceRequestTest);
FRIEND_TEST(ClusterResourceSchedulerTest, SchedulingUpdateTotalResourcesTest);
FRIEND_TEST(ClusterResourceSchedulerTest,
diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc
index 3fad6456f2222..35afa17f94b13 100644
--- a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc
+++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc
@@ -123,6 +123,7 @@ scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode(
const rpc::SchedulingStrategy &scheduling_strategy,
bool actor_creation,
bool force_spillback,
+ const std::string &preferred_node_id,
int64_t *total_violations,
bool *is_infeasible) {
// The zero cpu actor is a special case that must be handled the same way by all
@@ -168,7 +169,8 @@ scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode(
scheduling_policy_->Schedule(resource_request,
SchedulingOptions::Hybrid(
/*avoid_local_node*/ force_spillback,
- /*require_node_available*/ force_spillback));
+ /*require_node_available*/ force_spillback,
+ preferred_node_id));
}
*is_infeasible = best_node_id.IsNil();
@@ -192,6 +194,7 @@ scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode(
bool requires_object_store_memory,
bool actor_creation,
bool force_spillback,
+ const std::string &preferred_node_id,
int64_t *total_violations,
bool *is_infeasible) {
ResourceRequest resource_request =
@@ -200,6 +203,7 @@ scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode(
scheduling_strategy,
actor_creation,
force_spillback,
+ preferred_node_id,
total_violations,
is_infeasible);
}
@@ -244,13 +248,13 @@ bool ClusterResourceScheduler::IsSchedulableOnNode(
scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode(
const TaskSpecification &task_spec,
- bool prioritize_local_node,
+ const std::string &preferred_node_id,
bool exclude_local_node,
bool requires_object_store_memory,
bool *is_infeasible) {
// If the local node is available, we should directly return it instead of
// going through the full hybrid policy since we don't want spillback.
- if (prioritize_local_node && !exclude_local_node &&
+ if (preferred_node_id == local_node_id_.Binary() && !exclude_local_node &&
IsSchedulableOnNode(local_node_id_,
task_spec.GetRequiredResources().GetResourceMap(),
requires_object_store_memory)) {
@@ -266,6 +270,7 @@ scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode(
requires_object_store_memory,
task_spec.IsActorCreationTask(),
exclude_local_node,
+ preferred_node_id,
&_unused,
is_infeasible);
@@ -276,7 +281,7 @@ scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode(
requires_object_store_memory)) {
// Prefer waiting on the local node since the local node is chosen for a reason (e.g.
// spread).
- if (prioritize_local_node) {
+ if (preferred_node_id == local_node_id_.Binary()) {
*is_infeasible = false;
return local_node_id_;
}
diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.h b/src/ray/raylet/scheduling/cluster_resource_scheduler.h
index 9c5b2e96ef11a..5e48369b91720 100644
--- a/src/ray/raylet/scheduling/cluster_resource_scheduler.h
+++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.h
@@ -79,7 +79,8 @@ class ClusterResourceScheduler {
/// In hybrid mode, see `scheduling_policy.h` for a description of the policy.
///
/// \param task_spec: Task/Actor to be scheduled.
- /// \param prioritize_local_node: true if we want to try out local node first.
+ /// \param preferred_node_id: the node where the task is preferred to be placed. An
+ /// empty `preferred_node_id` (string) means no preferred node.
/// \param exclude_local_node: true if we want to avoid local node. This will cancel
/// prioritize_local_node if set to true.
/// \param requires_object_store_memory: take object store memory usage as part of
@@ -90,7 +91,7 @@ class ClusterResourceScheduler {
/// \return emptry string, if no node can schedule the current request; otherwise,
/// return the string name of a node that can schedule the resource request.
scheduling::NodeID GetBestSchedulableNode(const TaskSpecification &task_spec,
- bool prioritize_local_node,
+ const std::string &preferred_node_id,
bool exclude_local_node,
bool requires_object_store_memory,
bool *is_infeasible);
@@ -159,10 +160,11 @@ class ClusterResourceScheduler {
/// \param scheduling_strategy: Strategy about how to schedule this task.
/// \param actor_creation: True if this is an actor creation task.
/// \param force_spillback: True if we want to avoid local node.
- /// \param violations: The number of soft constraint violations associated
+ /// \param preferred_node_id: The node where the task is preferred to be placed.
+ /// \param violations[out]: The number of soft constraint violations associated
/// with the node returned by this function (assuming
/// a node that can schedule resource_request is found).
- /// \param is_infeasible[in]: It is set true if the task is not schedulable because it
+ /// \param is_infeasible[out]: It is set true if the task is not schedulable because it
/// is infeasible.
///
/// \return -1, if no node can schedule the current request; otherwise,
@@ -172,6 +174,7 @@ class ClusterResourceScheduler {
const rpc::SchedulingStrategy &scheduling_strategy,
bool actor_creation,
bool force_spillback,
+ const std::string &preferred_node_id,
int64_t *violations,
bool *is_infeasible);
@@ -187,6 +190,7 @@ class ClusterResourceScheduler {
bool requires_object_store_memory,
bool actor_creation,
bool force_spillback,
+ const std::string &preferred_node_id,
int64_t *violations,
bool *is_infeasible);
@@ -216,6 +220,7 @@ class ClusterResourceScheduler {
FRIEND_TEST(ClusterResourceSchedulerTest, SchedulingAddOrUpdateNodeTest);
FRIEND_TEST(ClusterResourceSchedulerTest, NodeAffinitySchedulingStrategyTest);
FRIEND_TEST(ClusterResourceSchedulerTest, SpreadSchedulingStrategyTest);
+ FRIEND_TEST(ClusterResourceSchedulerTest, SchedulingWithPreferredNodeTest);
FRIEND_TEST(ClusterResourceSchedulerTest, SchedulingResourceRequestTest);
FRIEND_TEST(ClusterResourceSchedulerTest, SchedulingUpdateTotalResourcesTest);
FRIEND_TEST(ClusterResourceSchedulerTest,
diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc
index 3c5ecf450413c..f1e2c49dd54fb 100644
--- a/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc
+++ b/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc
@@ -310,6 +310,7 @@ TEST_F(ClusterResourceSchedulerTest, NodeAffinitySchedulingStrategyTest) {
false,
false,
false,
+ std::string(),
&violations,
&is_infeasible);
ASSERT_EQ(node_id_1, remote_node_id);
@@ -322,6 +323,7 @@ TEST_F(ClusterResourceSchedulerTest, NodeAffinitySchedulingStrategyTest) {
false,
false,
false,
+ std::string(),
&violations,
&is_infeasible);
ASSERT_EQ(node_id_2, local_node_id);
@@ -334,6 +336,7 @@ TEST_F(ClusterResourceSchedulerTest, NodeAffinitySchedulingStrategyTest) {
false,
false,
false,
+ std::string(),
&violations,
&is_infeasible);
ASSERT_TRUE(node_id_3.IsNil());
@@ -346,6 +349,7 @@ TEST_F(ClusterResourceSchedulerTest, NodeAffinitySchedulingStrategyTest) {
false,
false,
false,
+ std::string(),
&violations,
&is_infeasible);
ASSERT_EQ(node_id_4, local_node_id);
@@ -371,6 +375,7 @@ TEST_F(ClusterResourceSchedulerTest, SpreadSchedulingStrategyTest) {
false,
false,
false,
+ std::string(),
&violations,
&is_infeasible);
absl::flat_hash_map resource_available({{"CPU", 9}});
@@ -381,12 +386,52 @@ TEST_F(ClusterResourceSchedulerTest, SpreadSchedulingStrategyTest) {
false,
false,
false,
+ std::string(),
&violations,
&is_infeasible);
ASSERT_EQ((std::set{node_id_1, node_id_2}),
(std::set{local_node_id, remote_node_id}));
}
+TEST_F(ClusterResourceSchedulerTest, SchedulingWithPreferredNodeTest) {
+ absl::flat_hash_map resource_total({{"CPU", 10}});
+ auto local_node_id = scheduling::NodeID(NodeID::FromRandom().Binary());
+ ClusterResourceScheduler resource_scheduler(
+ local_node_id, resource_total, is_node_available_fn_);
+ AssertPredefinedNodeResources();
+ auto remote_node_id = scheduling::NodeID(NodeID::FromRandom().Binary());
+ resource_scheduler.GetClusterResourceManager().AddOrUpdateNode(
+ remote_node_id, resource_total, resource_total);
+
+ absl::flat_hash_map resource_request({{"CPU", 5}});
+ int64_t violations;
+ bool is_infeasible;
+ rpc::SchedulingStrategy scheduling_strategy;
+ scheduling_strategy.mutable_default_scheduling_strategy();
+ // Select node with the remote node preferred.
+ auto node_id_1 = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ false,
+ remote_node_id.Binary(),
+ &violations,
+ &is_infeasible);
+
+ // If no preferred node specified, then still prefer the local one.
+ auto node_id_2 = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ false,
+ std::string(),
+ &violations,
+ &is_infeasible);
+
+ ASSERT_EQ((std::set{node_id_1, node_id_2}),
+ (std::set{remote_node_id, local_node_id}));
+}
+
TEST_F(ClusterResourceSchedulerTest, SchedulingUpdateAvailableResourcesTest) {
// Create cluster resources.
NodeResources node_resources = CreateNodeResources({{ResourceID::CPU(), 10},
@@ -408,8 +453,13 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingUpdateAvailableResourcesTest) {
bool is_infeasible;
rpc::SchedulingStrategy scheduling_strategy;
scheduling_strategy.mutable_default_scheduling_strategy();
- auto node_id = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, &violations, &is_infeasible);
+ auto node_id = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ std::string(),
+ &violations,
+ &is_infeasible);
ASSERT_EQ(node_id.ToInt(), 1);
ASSERT_TRUE(violations == 0);
@@ -527,8 +577,13 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingResourceRequestTest) {
ResourceRequest resource_request = CreateResourceRequest({{ResourceID::CPU(), 11}});
int64_t violations;
bool is_infeasible;
- auto node_id = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, &violations, &is_infeasible);
+ auto node_id = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ std::string(),
+ &violations,
+ &is_infeasible);
ASSERT_TRUE(node_id.IsNil());
}
@@ -537,8 +592,13 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingResourceRequestTest) {
ResourceRequest resource_request = CreateResourceRequest({{ResourceID::CPU(), 5}});
int64_t violations;
bool is_infeasible;
- auto node_id = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, &violations, &is_infeasible);
+ auto node_id = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ std::string(),
+ &violations,
+ &is_infeasible);
ASSERT_TRUE(!node_id.IsNil());
ASSERT_TRUE(violations == 0);
}
@@ -548,8 +608,13 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingResourceRequestTest) {
{{ResourceID::CPU(), 5}, {ResourceID::Memory(), 2}, {ResourceID("custom1"), 11}});
int64_t violations;
bool is_infeasible;
- auto node_id = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, &violations, &is_infeasible);
+ auto node_id = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ std::string(),
+ &violations,
+ &is_infeasible);
ASSERT_TRUE(node_id.IsNil());
}
// Custom resources, no constraint violation.
@@ -558,8 +623,13 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingResourceRequestTest) {
{{ResourceID::CPU(), 5}, {ResourceID::Memory(), 2}, {ResourceID("custom1"), 5}});
int64_t violations;
bool is_infeasible;
- auto node_id = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, &violations, &is_infeasible);
+ auto node_id = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ std::string(),
+ &violations,
+ &is_infeasible);
ASSERT_TRUE(!node_id.IsNil());
ASSERT_TRUE(violations == 0);
}
@@ -571,8 +641,13 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingResourceRequestTest) {
{ResourceID("custom100"), 5}});
int64_t violations;
bool is_infeasible;
- auto node_id = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, &violations, &is_infeasible);
+ auto node_id = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ std::string(),
+ &violations,
+ &is_infeasible);
ASSERT_TRUE(node_id.IsNil());
}
// Placement hints, no constraint violation.
@@ -581,8 +656,13 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingResourceRequestTest) {
{{ResourceID::CPU(), 5}, {ResourceID::Memory(), 2}, {ResourceID("custom1"), 5}});
int64_t violations;
bool is_infeasible;
- auto node_id = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, &violations, &is_infeasible);
+ auto node_id = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ std::string(),
+ &violations,
+ &is_infeasible);
ASSERT_TRUE(!node_id.IsNil());
ASSERT_TRUE(violations == 0);
}
@@ -888,6 +968,7 @@ TEST_F(ClusterResourceSchedulerTest, DeadNodeTest) {
false,
false,
false,
+ std::string(),
&violations,
&is_infeasible));
EXPECT_CALL(*gcs_client_->mock_node_accessor, Get(node_id, ::testing::_))
@@ -899,6 +980,7 @@ TEST_F(ClusterResourceSchedulerTest, DeadNodeTest) {
false,
false,
false,
+ std::string(),
&violations,
&is_infeasible)
.IsNil());
@@ -1093,6 +1175,7 @@ TEST_F(ClusterResourceSchedulerTest, TestAlwaysSpillInfeasibleTask) {
false,
false,
false,
+ std::string(),
&total_violations,
&is_infeasible)
.IsNil());
@@ -1108,6 +1191,7 @@ TEST_F(ClusterResourceSchedulerTest, TestAlwaysSpillInfeasibleTask) {
false,
false,
false,
+ std::string(),
&total_violations,
&is_infeasible));
@@ -1122,6 +1206,7 @@ TEST_F(ClusterResourceSchedulerTest, TestAlwaysSpillInfeasibleTask) {
false,
false,
false,
+ std::string(),
&total_violations,
&is_infeasible));
}
@@ -1339,19 +1424,29 @@ TEST_F(ClusterResourceSchedulerTest, DirtyLocalViewTest) {
resource_scheduler.GetClusterResourceManager().AddOrUpdateNode(
remote, {{"CPU", 2.}}, {{"CPU", num_slots_available}});
for (int j = 0; j < num_slots_available; j++) {
- ASSERT_EQ(
- remote,
- resource_scheduler.GetBestSchedulableNode(
- task_spec, scheduling_strategy, false, false, true, &t, &is_infeasible));
+ ASSERT_EQ(remote,
+ resource_scheduler.GetBestSchedulableNode(task_spec,
+ scheduling_strategy,
+ false,
+ false,
+ true,
+ std::string(),
+ &t,
+ &is_infeasible));
// Allocate remote resources.
ASSERT_TRUE(resource_scheduler.AllocateRemoteTaskResources(remote, task_spec));
}
// Our local view says there are not enough resources on the remote node to
// schedule another task.
- ASSERT_EQ(
- resource_scheduler.GetBestSchedulableNode(
- task_spec, scheduling_strategy, false, false, true, &t, &is_infeasible),
- scheduling::NodeID::Nil());
+ ASSERT_EQ(resource_scheduler.GetBestSchedulableNode(task_spec,
+ scheduling_strategy,
+ false,
+ false,
+ true,
+ std::string(),
+ &t,
+ &is_infeasible),
+ scheduling::NodeID::Nil());
ASSERT_FALSE(
resource_scheduler.GetLocalResourceManager().AllocateLocalTaskResources(
task_spec, task_allocation));
@@ -1371,32 +1466,62 @@ TEST_F(ClusterResourceSchedulerTest, DynamicResourceTest) {
rpc::SchedulingStrategy scheduling_strategy;
scheduling_strategy.mutable_default_scheduling_strategy();
- auto result = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, false, &t, &is_infeasible);
+ auto result = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ false,
+ std::string(),
+ &t,
+ &is_infeasible);
ASSERT_TRUE(result.IsNil());
resource_scheduler.GetLocalResourceManager().AddLocalResourceInstances(
scheduling::ResourceID("custom123"), {0., 1.0, 1.0});
- result = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, false, &t, &is_infeasible);
+ result = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ false,
+ std::string(),
+ &t,
+ &is_infeasible);
ASSERT_FALSE(result.IsNil()) << resource_scheduler.DebugString();
resource_request["custom123"] = 3;
- result = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, false, &t, &is_infeasible);
+ result = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ false,
+ std::string(),
+ &t,
+ &is_infeasible);
ASSERT_TRUE(result.IsNil());
resource_scheduler.GetLocalResourceManager().AddLocalResourceInstances(
scheduling::ResourceID("custom123"), {1.0});
- result = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, false, &t, &is_infeasible);
+ result = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ false,
+ std::string(),
+ &t,
+ &is_infeasible);
ASSERT_FALSE(result.IsNil());
resource_scheduler.GetLocalResourceManager().DeleteLocalResource(
scheduling::ResourceID("custom123"));
- result = resource_scheduler.GetBestSchedulableNode(
- resource_request, scheduling_strategy, false, false, false, &t, &is_infeasible);
+ result = resource_scheduler.GetBestSchedulableNode(resource_request,
+ scheduling_strategy,
+ false,
+ false,
+ false,
+ std::string(),
+ &t,
+ &is_infeasible);
ASSERT_TRUE(result.IsNil());
}
@@ -1436,6 +1561,7 @@ TEST_F(ClusterResourceSchedulerTest, TestForceSpillback) {
false,
false,
/*force_spillback=*/false,
+ std::string(),
&total_violations,
&is_infeasible),
scheduling::NodeID("local"));
@@ -1446,6 +1572,7 @@ TEST_F(ClusterResourceSchedulerTest, TestForceSpillback) {
false,
false,
/*force_spillback=*/true,
+ std::string(),
&total_violations,
&is_infeasible),
scheduling::NodeID::Nil());
@@ -1457,6 +1584,7 @@ TEST_F(ClusterResourceSchedulerTest, TestForceSpillback) {
false,
false,
/*force_spillback=*/true,
+ std::string(),
&total_violations,
&is_infeasible),
scheduling::NodeID::Nil());
@@ -1467,6 +1595,7 @@ TEST_F(ClusterResourceSchedulerTest, TestForceSpillback) {
false,
false,
/*force_spillback=*/true,
+ std::string(),
&total_violations,
&is_infeasible),
node_ids[51]);
@@ -1574,6 +1703,7 @@ TEST_F(ClusterResourceSchedulerTest, AffinityWithBundleScheduleTest) {
scheduling_strategy,
true,
false,
+ std::string(),
&violations,
&is_infeasible),
except_node_id);
diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc
index 64e25862aeabb..1d87aaaca0e88 100644
--- a/src/ray/raylet/scheduling/cluster_task_manager.cc
+++ b/src/ray/raylet/scheduling/cluster_task_manager.cc
@@ -97,7 +97,8 @@ void ClusterTaskManager::ScheduleAndDispatchTasks() {
<< task.GetTaskSpecification().TaskId();
auto scheduling_node_id = cluster_resource_scheduler_->GetBestSchedulableNode(
task.GetTaskSpecification(),
- work->PrioritizeLocalNode(),
+ /*preferred_node_id*/ work->PrioritizeLocalNode() ? self_node_id_.Binary()
+ : task.GetPreferredNodeID(),
/*exclude_local_node*/ false,
/*requires_object_store_memory*/ false,
&is_infeasible);
@@ -191,7 +192,8 @@ void ClusterTaskManager::TryScheduleInfeasibleTask() {
bool is_infeasible;
cluster_resource_scheduler_->GetBestSchedulableNode(
task.GetTaskSpecification(),
- work->PrioritizeLocalNode(),
+ /*preferred_node_id*/ work->PrioritizeLocalNode() ? self_node_id_.Binary()
+ : task.GetPreferredNodeID(),
/*exclude_local_node*/ false,
/*requires_object_store_memory*/ false,
&is_infeasible);
diff --git a/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.cc b/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.cc
index e23664d81e26e..0b83a52b242c5 100644
--- a/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.cc
+++ b/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.cc
@@ -28,14 +28,18 @@ scheduling::NodeID HybridSchedulingPolicy::HybridPolicyWithFilter(
float spread_threshold,
bool force_spillback,
bool require_node_available,
+ const std::string &preferred_node,
NodeFilter node_filter) {
- // Step 1: Generate the traversal order. We guarantee that the first node is local, to
- // encourage local scheduling. The rest of the traversal order should be globally
- // consistent, to encourage using "warm" workers.
+ // Step 1: Generate the traversal order. We guarantee that the first node is local (or
+ // the preferred node, if provided), to encourage local/preferred scheduling. The rest
+ // of the traversal order should be globally consistent, to encourage using "warm"
+ // workers.
std::vector round;
round.reserve(nodes_.size());
- const auto local_it = nodes_.find(local_node_id_);
- RAY_CHECK(local_it != nodes_.end());
+ auto preferred_node_id =
+ preferred_node.empty() ? local_node_id_ : scheduling::NodeID(preferred_node);
+ auto preferred_it = nodes_.find(preferred_node_id);
+ RAY_CHECK(preferred_it != nodes_.end());
auto predicate = [this, node_filter](scheduling::NodeID node_id,
const NodeResources &node_resources) {
if (!is_node_available_(node_id)) {
@@ -52,18 +56,18 @@ scheduling::NodeID HybridSchedulingPolicy::HybridPolicyWithFilter(
return !has_gpu;
};
- const auto &local_node_view = local_it->second.GetLocalView();
- // If we should include local node at all, make sure it is at the front of the list
- // so that
+ const auto &preferred_node_view = preferred_it->second.GetLocalView();
+ // If we should include local/preferred node at all, make sure it is at the front of the
+ // list so that
// 1. It's first in traversal order.
// 2. It's easy to avoid sorting it.
- if (predicate(local_node_id_, local_node_view) && !force_spillback) {
- round.push_back(local_node_id_);
+ if (predicate(preferred_node_id, preferred_node_view) && !force_spillback) {
+ round.push_back(preferred_node_id);
}
const auto start_index = round.size();
for (const auto &pair : nodes_) {
- if (pair.first != local_node_id_ &&
+ if (pair.first != preferred_node_id &&
predicate(pair.first, pair.second.GetLocalView())) {
round.push_back(pair.first);
}
@@ -144,7 +148,8 @@ scheduling::NodeID HybridSchedulingPolicy::Schedule(
return HybridPolicyWithFilter(resource_request,
options.spread_threshold,
options.avoid_local_node,
- options.require_node_available);
+ options.require_node_available,
+ options.preferred_node_id);
}
// Try schedule on non-GPU nodes.
@@ -152,6 +157,7 @@ scheduling::NodeID HybridSchedulingPolicy::Schedule(
options.spread_threshold,
options.avoid_local_node,
/*require_node_available*/ true,
+ options.preferred_node_id,
NodeFilter::kNonGpu);
if (!best_node_id.IsNil()) {
return best_node_id;
@@ -162,7 +168,8 @@ scheduling::NodeID HybridSchedulingPolicy::Schedule(
return HybridPolicyWithFilter(resource_request,
options.spread_threshold,
options.avoid_local_node,
- options.require_node_available);
+ options.require_node_available,
+ options.preferred_node_id);
}
} // namespace raylet_scheduling_policy
diff --git a/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.h b/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.h
index 4dd3a055889d6..773eaf8fcb6b0 100644
--- a/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.h
+++ b/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.h
@@ -89,6 +89,7 @@ class HybridSchedulingPolicy : public ISchedulingPolicy {
float spread_threshold,
bool force_spillback,
bool require_available,
+ const std::string &preferred_node,
NodeFilter node_filter = NodeFilter::kAny);
};
} // namespace raylet_scheduling_policy
diff --git a/src/ray/raylet/scheduling/policy/scheduling_options.h b/src/ray/raylet/scheduling/policy/scheduling_options.h
index 42b852e52020f..cd624f9b610c8 100644
--- a/src/ray/raylet/scheduling/policy/scheduling_options.h
+++ b/src/ray/raylet/scheduling/policy/scheduling_options.h
@@ -57,12 +57,17 @@ struct SchedulingOptions {
}
// construct option for hybrid scheduling policy.
- static SchedulingOptions Hybrid(bool avoid_local_node, bool require_node_available) {
+ static SchedulingOptions Hybrid(bool avoid_local_node,
+ bool require_node_available,
+ const std::string &preferred_node_id = std::string()) {
return SchedulingOptions(SchedulingType::HYBRID,
RayConfig::instance().scheduler_spread_threshold(),
avoid_local_node,
require_node_available,
- RayConfig::instance().scheduler_avoid_gpu_nodes());
+ RayConfig::instance().scheduler_avoid_gpu_nodes(),
+ /*max_cpu_fraction_per_node*/ 1.0,
+ /*scheduling_context*/ nullptr,
+ preferred_node_id);
}
static SchedulingOptions NodeAffinity(bool avoid_local_node,
@@ -152,6 +157,9 @@ struct SchedulingOptions {
std::shared_ptr scheduling_context;
std::string node_affinity_node_id;
bool node_affinity_soft = false;
+ // The node where the task is preferred to be placed. By default, this node id
+ // is empty, which means no preferred node.
+ std::string preferred_node_id;
private:
SchedulingOptions(SchedulingType type,
@@ -160,14 +168,16 @@ struct SchedulingOptions {
bool require_node_available,
bool avoid_gpu_nodes,
double max_cpu_fraction_per_node = 1.0,
- std::shared_ptr scheduling_context = nullptr)
+ std::shared_ptr scheduling_context = nullptr,
+ const std::string &preferred_node_id = std::string())
: scheduling_type(type),
spread_threshold(spread_threshold),
avoid_local_node(avoid_local_node),
require_node_available(require_node_available),
avoid_gpu_nodes(avoid_gpu_nodes),
max_cpu_fraction_per_node(max_cpu_fraction_per_node),
- scheduling_context(std::move(scheduling_context)) {}
+ scheduling_context(std::move(scheduling_context)),
+ preferred_node_id(preferred_node_id) {}
friend class ::ray::raylet::SchedulingPolicyTest;
};
diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h
index b1cc41bd3c068..c2453c88a1df8 100644
--- a/src/ray/rpc/gcs_server/gcs_rpc_server.h
+++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h
@@ -78,6 +78,8 @@ namespace rpc {
class JobInfoGcsServiceHandler {
public:
+ using JobFinishListenerCallback = std::function;
+
virtual ~JobInfoGcsServiceHandler() = default;
virtual void HandleAddJob(AddJobRequest request,
@@ -92,8 +94,7 @@ class JobInfoGcsServiceHandler {
GetAllJobInfoReply *reply,
SendReplyCallback send_reply_callback) = 0;
- virtual void AddJobFinishedListener(
- std::function)> listener) = 0;
+ virtual void AddJobFinishedListener(JobFinishListenerCallback listener) = 0;
virtual void HandleReportJobError(ReportJobErrorRequest request,
ReportJobErrorReply *reply,