Skip to content

Commit d877d27

Browse files
committed
fix: Fix Spark template to work correctly on feast init -t spark (#2393)
Signed-off-by: Danny Chiao <danny@tecton.ai>
1 parent e7a3b3f commit d877d27

File tree

2 files changed

+27
-40
lines changed

2 files changed

+27
-40
lines changed

sdk/python/feast/templates/spark/bootstrap.py

+25-38
Original file line number | Diff line number | Diff line change
@@ -1,48 +1,35 @@
1-
from datetime import datetime, timedelta
2-
from pathlib import Path
3-
4-
from pyspark.sql import SparkSession
5-
6-
from feast.driver_test_data import (
7-
create_customer_daily_profile_df,
8-
create_driver_hourly_stats_df,
9-
)
10-
11-
CURRENT_DIR = Path(__file__).parent
12-
DRIVER_ENTITIES = [1001, 1002, 1003]
13-
CUSTOMER_ENTITIES = [201, 202, 203]
14-
START_DATE = datetime.strptime("2022-01-01", "%Y-%m-%d")
15-
END_DATE = START_DATE + timedelta(days=7)
16-
17-
181
def bootstrap():
192
# Bootstrap() will automatically be called from the init_repo() during `feast init`
20-
generate_example_data(
21-
spark_session=SparkSession.builder.getOrCreate(), base_dir=str(CURRENT_DIR),
22-
)
23-
3+
import pathlib
4+
from datetime import datetime, timedelta
245

25-
def example_data_exists(base_dir: str) -> bool:
26-
for path in [
27-
Path(base_dir) / "data" / "driver_hourly_stats",
28-
Path(base_dir) / "data" / "customer_daily_profile",
29-
]:
30-
if not path.exists():
31-
return False
32-
return True
6+
from feast.driver_test_data import (
7+
create_customer_daily_profile_df,
8+
create_driver_hourly_stats_df,
9+
)
3310

11+
repo_path = pathlib.Path(__file__).parent.absolute()
12+
data_path = repo_path / "data"
13+
data_path.mkdir(exist_ok=True)
3414

35-
def generate_example_data(spark_session: SparkSession, base_dir: str) -> None:
36-
spark_session.createDataFrame(
37-
data=create_driver_hourly_stats_df(DRIVER_ENTITIES, START_DATE, END_DATE)
38-
).write.parquet(
39-
path=str(Path(base_dir) / "data" / "driver_hourly_stats"), mode="overwrite",
15+
driver_entities = [1001, 1002, 1003]
16+
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
17+
start_date = end_date - timedelta(days=15)
18+
driver_stats_df = create_driver_hourly_stats_df(
19+
driver_entities, start_date, end_date
20+
)
21+
driver_stats_df.to_parquet(
22+
path=str(data_path / "driver_hourly_stats.parquet"),
23+
allow_truncated_timestamps=True,
4024
)
4125

42-
spark_session.createDataFrame(
43-
data=create_customer_daily_profile_df(CUSTOMER_ENTITIES, START_DATE, END_DATE)
44-
).write.parquet(
45-
path=str(Path(base_dir) / "data" / "customer_daily_profile"), mode="overwrite",
26+
customer_entities = [201, 202, 203]
27+
customer_profile_df = create_customer_daily_profile_df(
28+
customer_entities, start_date, end_date
29+
)
30+
customer_profile_df.to_parquet(
31+
path=str(data_path / "customer_daily_profile.parquet"),
32+
allow_truncated_timestamps=True,
4633
)
4734

4835

sdk/python/feast/templates/spark/example.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -24,14 +24,14 @@
2424
# Sources
2525
driver_hourly_stats = SparkSource(
2626
name="driver_hourly_stats",
27-
path=f"{CURRENT_DIR}/data/driver_hourly_stats",
27+
path=f"{CURRENT_DIR}/data/driver_hourly_stats.parquet",
2828
file_format="parquet",
2929
event_timestamp_column="event_timestamp",
3030
created_timestamp_column="created",
3131
)
3232
customer_daily_profile = SparkSource(
3333
name="customer_daily_profile",
34-
path=f"{CURRENT_DIR}/data/customer_daily_profile",
34+
path=f"{CURRENT_DIR}/data/customer_daily_profile.parquet",
3535
file_format="parquet",
3636
event_timestamp_column="event_timestamp",
3737
created_timestamp_column="created",

0 commit comments

Comments
 (0)