@@ -571,7 +571,7 @@ def _plan(
571
571
new_infra_proto = new_infra .to_proto ()
572
572
infra_diff = diff_infra_protos (current_infra_proto , new_infra_proto )
573
573
574
- return ( registry_diff , infra_diff , new_infra )
574
+ return registry_diff , infra_diff , new_infra
575
575
576
576
@log_exceptions_and_usage
577
577
def _apply_diffs (
@@ -659,16 +659,23 @@ def apply(
659
659
]
660
660
odfvs_to_update = [ob for ob in objects if isinstance (ob , OnDemandFeatureView )]
661
661
services_to_update = [ob for ob in objects if isinstance (ob , FeatureService )]
662
- data_sources_to_update = [ob for ob in objects if isinstance (ob , DataSource )]
663
-
664
- if len (entities_to_update ) + len (views_to_update ) + len (
665
- request_views_to_update
666
- ) + len (odfvs_to_update ) + len (services_to_update ) + len (
667
- data_sources_to_update
668
- ) != len (
669
- objects
670
- ):
671
- raise ValueError ("Unknown object type provided as part of apply() call" )
662
+ data_sources_set_to_update = {
663
+ ob for ob in objects if isinstance (ob , DataSource )
664
+ }
665
+
666
+ for fv in views_to_update :
667
+ data_sources_set_to_update .add (fv .batch_source )
668
+ if fv .stream_source :
669
+ data_sources_set_to_update .add (fv .stream_source )
670
+
671
+ for rfv in request_views_to_update :
672
+ data_sources_set_to_update .add (rfv .request_data_source )
673
+
674
+ for odfv in odfvs_to_update :
675
+ for v in odfv .input_request_data_sources .values ():
676
+ data_sources_set_to_update .add (v )
677
+
678
+ data_sources_to_update = list (data_sources_set_to_update )
672
679
673
680
# Validate all feature views and make inferences.
674
681
self ._validate_all_feature_views (
0 commit comments