From 73a3fda084625c98b9e1c72ad89645485e7834c4 Mon Sep 17 00:00:00 2001 From: tokoko Date: Mon, 27 May 2024 14:41:57 +0000 Subject: [PATCH 1/2] make registry server writable Signed-off-by: tokoko --- protos/feast/core/Transformation.proto | 2 - protos/feast/registry/RegistryServer.proto | 108 +++++++++++- .../feast/infra/registry/base_registry.py | 8 +- sdk/python/feast/infra/registry/remote.py | 121 ++++++++++++-- sdk/python/feast/registry_server.py | 158 +++++++++++++++++- .../registration/test_universal_registry.py | 122 ++++++++------ 6 files changed, 442 insertions(+), 77 deletions(-) diff --git a/protos/feast/core/Transformation.proto b/protos/feast/core/Transformation.proto index 5cb53e690fa..7033f553f16 100644 --- a/protos/feast/core/Transformation.proto +++ b/protos/feast/core/Transformation.proto @@ -5,8 +5,6 @@ option go_package = "github.com/feast-dev/feast/go/protos/feast/core"; option java_outer_classname = "FeatureTransformationProto"; option java_package = "feast.proto.core"; -import "google/protobuf/duration.proto"; - // Serialized representation of python function. message UserDefinedFunctionV2 { // The function name diff --git a/protos/feast/registry/RegistryServer.proto b/protos/feast/registry/RegistryServer.proto index e99987eb2da..3ca7398fdc1 100644 --- a/protos/feast/registry/RegistryServer.proto +++ b/protos/feast/registry/RegistryServer.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package feast.registry; import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; import "feast/core/Registry.proto"; import "feast/core/Entity.proto"; import "feast/core/DataSource.proto"; @@ -16,16 +17,22 @@ import "feast/core/InfraObject.proto"; service RegistryServer{ // Entity RPCs + rpc ApplyEntity (ApplyEntityRequest) returns (google.protobuf.Empty) {} rpc GetEntity (GetEntityRequest) returns (feast.core.Entity) {} rpc ListEntities (ListEntitiesRequest) returns (ListEntitiesResponse) {} + rpc DeleteEntity (DeleteEntityRequest) returns (google.protobuf.Empty) {} // DataSource RPCs + rpc ApplyDataSource (ApplyDataSourceRequest) returns (google.protobuf.Empty) {} rpc GetDataSource (GetDataSourceRequest) returns (feast.core.DataSource) {} rpc ListDataSources (ListDataSourcesRequest) returns (ListDataSourcesResponse) {} + rpc DeleteDataSource (DeleteDataSourceRequest) returns (google.protobuf.Empty) {} // FeatureView RPCs + rpc ApplyFeatureView (ApplyFeatureViewRequest) returns (google.protobuf.Empty) {} rpc GetFeatureView (GetFeatureViewRequest) returns (feast.core.FeatureView) {} rpc ListFeatureViews (ListFeatureViewsRequest) returns (ListFeatureViewsResponse) {} + rpc DeleteFeatureView (DeleteFeatureViewRequest) returns (google.protobuf.Empty) {} // StreamFeatureView RPCs rpc GetStreamFeatureView (GetStreamFeatureViewRequest) returns (feast.core.StreamFeatureView) {} @@ -36,19 +43,28 @@ service RegistryServer{ rpc ListOnDemandFeatureViews (ListOnDemandFeatureViewsRequest) returns (ListOnDemandFeatureViewsResponse) {} // FeatureService RPCs + rpc ApplyFeatureService (ApplyFeatureServiceRequest) returns (google.protobuf.Empty) {} rpc GetFeatureService (GetFeatureServiceRequest) returns (feast.core.FeatureService) {} rpc ListFeatureServices (ListFeatureServicesRequest) returns (ListFeatureServicesResponse) {} + rpc DeleteFeatureService (DeleteFeatureServiceRequest) returns (google.protobuf.Empty) {} // SavedDataset RPCs + rpc ApplySavedDataset (ApplySavedDatasetRequest) returns (google.protobuf.Empty) {} rpc GetSavedDataset (GetSavedDatasetRequest) returns (feast.core.SavedDataset) {} rpc ListSavedDatasets (ListSavedDatasetsRequest) returns (ListSavedDatasetsResponse) {} + rpc DeleteSavedDataset (DeleteSavedDatasetRequest) returns (google.protobuf.Empty) {} // ValidationReference RPCs + rpc ApplyValidationReference (ApplyValidationReferenceRequest) returns (google.protobuf.Empty) {} rpc GetValidationReference (GetValidationReferenceRequest) returns (feast.core.ValidationReference) {} rpc ListValidationReferences (ListValidationReferencesRequest) returns (ListValidationReferencesResponse) {} - + rpc DeleteValidationReference (DeleteValidationReferenceRequest) returns (google.protobuf.Empty) {} + + rpc ApplyMaterialization (ApplyMaterializationRequest) returns (google.protobuf.Empty) {} rpc ListProjectMetadata (ListProjectMetadataRequest) returns (ListProjectMetadataResponse) {} + rpc UpdateInfra (UpdateInfraRequest) returns (google.protobuf.Empty) {} rpc GetInfra (GetInfraRequest) returns (feast.core.Infra) {} + rpc Commit (google.protobuf.Empty) returns (google.protobuf.Empty) {} rpc Refresh (RefreshRequest) returns (google.protobuf.Empty) {} rpc Proto (google.protobuf.Empty) returns (feast.core.Registry) {} @@ -58,6 +74,12 @@ message RefreshRequest { string project = 1; } +message UpdateInfraRequest { + feast.core.Infra infra = 1; + string project = 2; + bool commit = 3; +} + message GetInfraRequest { string project = 1; bool allow_cache = 2; @@ -72,6 +94,20 @@ message ListProjectMetadataResponse { repeated feast.core.ProjectMetadata project_metadata = 1; } +message ApplyMaterializationRequest { + feast.core.FeatureView feature_view = 1; + string project = 2; + google.protobuf.Timestamp start_date = 3; + google.protobuf.Timestamp end_date = 4; + bool commit = 5; +} + +message ApplyEntityRequest { + feast.core.Entity entity = 1; + string project = 2; + bool commit = 3; +} + message GetEntityRequest { string name = 1; string project = 2; @@ -87,8 +123,20 @@ message ListEntitiesResponse { repeated feast.core.Entity entities = 1; } +message DeleteEntityRequest { + string name = 1; + string project = 2; + bool commit = 3; +} + // DataSources +message ApplyDataSourceRequest { + feast.core.DataSource data_source = 1; + string project = 2; + bool commit = 3; +} + message GetDataSourceRequest { string name = 1; string project = 2; @@ -104,8 +152,24 @@ message ListDataSourcesResponse { repeated feast.core.DataSource data_sources = 1; } +message DeleteDataSourceRequest { + string name = 1; + string project = 2; + bool commit = 3; +} + // FeatureViews +message ApplyFeatureViewRequest { + oneof base_feature_view { + feast.core.FeatureView feature_view = 1; + feast.core.OnDemandFeatureView on_demand_feature_view = 2; + feast.core.StreamFeatureView stream_feature_view = 3; + } + string project = 4; + bool commit = 5; +} + message GetFeatureViewRequest { string name = 1; string project = 2; @@ -121,6 +185,12 @@ message ListFeatureViewsResponse { repeated feast.core.FeatureView feature_views = 1; } +message DeleteFeatureViewRequest { + string name = 1; + string project = 2; + bool commit = 3; +} + // StreamFeatureView message GetStreamFeatureViewRequest { @@ -157,6 +227,12 @@ message ListOnDemandFeatureViewsResponse { // FeatureServices +message ApplyFeatureServiceRequest { + feast.core.FeatureService feature_service = 1; + string project = 2; + bool commit = 3; +} + message GetFeatureServiceRequest { string name = 1; string project = 2; @@ -172,8 +248,20 @@ message ListFeatureServicesResponse { repeated feast.core.FeatureService feature_services = 1; } +message DeleteFeatureServiceRequest { + string name = 1; + string project = 2; + bool commit = 3; +} + // SavedDataset +message ApplySavedDatasetRequest { + feast.core.SavedDataset saved_dataset = 1; + string project = 2; + bool commit = 3; +} + message GetSavedDatasetRequest { string name = 1; string project = 2; @@ -189,8 +277,20 @@ message ListSavedDatasetsResponse { repeated feast.core.SavedDataset saved_datasets = 1; } +message DeleteSavedDatasetRequest { + string name = 1; + string project = 2; + bool commit = 3; +} + // ValidationReference +message ApplyValidationReferenceRequest { + feast.core.ValidationReference validation_reference = 1; + string project = 2; + bool commit = 3; +} + message GetValidationReferenceRequest { string name = 1; string project = 2; @@ -205,3 +305,9 @@ message ListValidationReferencesRequest { message ListValidationReferencesResponse { repeated feast.core.ValidationReference validation_references = 1; } + +message DeleteValidationReferenceRequest { + string name = 1; + string project = 2; + bool commit = 3; +} \ No newline at end of file diff --git a/sdk/python/feast/infra/registry/base_registry.py b/sdk/python/feast/infra/registry/base_registry.py index ed1fc3ab879..b52749a9b2f 100644 --- a/sdk/python/feast/infra/registry/base_registry.py +++ b/sdk/python/feast/infra/registry/base_registry.py @@ -406,18 +406,14 @@ def get_saved_dataset( """ raise NotImplementedError - def delete_saved_dataset(self, name: str, project: str, allow_cache: bool = False): + def delete_saved_dataset(self, name: str, project: str, commit: bool = True): """ Delete a saved dataset. Args: name: Name of dataset project: Feast project that this dataset belongs to - allow_cache: Whether to allow returning this dataset from a cached registry - - Returns: - Returns either the specified SavedDataset, or raises an exception if - none is found + commit: Whether the change should be persisted immediately """ raise NotImplementedError diff --git a/sdk/python/feast/infra/registry/remote.py b/sdk/python/feast/infra/registry/remote.py index f93e1ab1c03..4336db232fb 100644 --- a/sdk/python/feast/infra/registry/remote.py +++ b/sdk/python/feast/infra/registry/remote.py @@ -4,12 +4,12 @@ import grpc from google.protobuf.empty_pb2 import Empty +from google.protobuf.timestamp_pb2 import Timestamp from pydantic import StrictStr from feast.base_feature_view import BaseFeatureView from feast.data_source import DataSource from feast.entity import Entity -from feast.errors import ReadOnlyRegistryException from feast.feature_service import FeatureService from feast.feature_view import FeatureView from feast.infra.infra_object import Infra @@ -43,10 +43,18 @@ def __init__( self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.channel) def apply_entity(self, entity: Entity, project: str, commit: bool = True): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.ApplyEntityRequest( + entity=entity.to_proto(), project=project, commit=commit + ) + + self.stub.ApplyEntity(request) def delete_entity(self, name: str, project: str, commit: bool = True): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.DeleteEntityRequest( + name=name, project=project, commit=commit + ) + + self.stub.DeleteEntity(request) def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: request = RegistryServer_pb2.GetEntityRequest( @@ -69,10 +77,18 @@ def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity] def apply_data_source( self, data_source: DataSource, project: str, commit: bool = True ): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.ApplyDataSourceRequest( + data_source=data_source.to_proto(), project=project, commit=commit + ) + + self.stub.ApplyDataSource(request) def delete_data_source(self, name: str, project: str, commit: bool = True): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.DeleteDataSourceRequest( + name=name, project=project, commit=commit + ) + + self.stub.DeleteDataSource(request) def get_data_source( self, name: str, project: str, allow_cache: bool = False @@ -101,10 +117,18 @@ def list_data_sources( def apply_feature_service( self, feature_service: FeatureService, project: str, commit: bool = True ): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.ApplyFeatureServiceRequest( + feature_service=feature_service.to_proto(), project=project, commit=commit + ) + + self.stub.ApplyFeatureService(request) def delete_feature_service(self, name: str, project: str, commit: bool = True): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.DeleteFeatureServiceRequest( + name=name, project=project, commit=commit + ) + + self.stub.DeleteFeatureService(request) def get_feature_service( self, name: str, project: str, allow_cache: bool = False @@ -134,10 +158,35 @@ def list_feature_services( def apply_feature_view( self, feature_view: BaseFeatureView, project: str, commit: bool = True ): - raise ReadOnlyRegistryException() + if isinstance(feature_view, StreamFeatureView): + arg_name = "stream_feature_view" + elif isinstance(feature_view, FeatureView): + arg_name = "feature_view" + elif isinstance(feature_view, OnDemandFeatureView): + arg_name = "on_demand_feature_view" + + request = RegistryServer_pb2.ApplyFeatureViewRequest( + feature_view=feature_view.to_proto() + if arg_name == "feature_view" + else None, + stream_feature_view=feature_view.to_proto() + if arg_name == "stream_feature_view" + else None, + on_demand_feature_view=feature_view.to_proto() + if arg_name == "on_demand_feature_view" + else None, + project=project, + commit=commit, + ) + + self.stub.ApplyFeatureView(request) def delete_feature_view(self, name: str, project: str, commit: bool = True): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.DeleteFeatureViewRequest( + name=name, project=project, commit=commit + ) + + self.stub.DeleteFeatureView(request) def get_stream_feature_view( self, name: str, project: str, allow_cache: bool = False @@ -222,7 +271,20 @@ def apply_materialization( end_date: datetime, commit: bool = True, ): - raise ReadOnlyRegistryException() + start_date_timestamp = Timestamp() + end_date_timestamp = Timestamp() + start_date_timestamp.FromDatetime(start_date) + end_date_timestamp.FromDatetime(end_date) + + request = RegistryServer_pb2.ApplyMaterializationRequest( + feature_view=feature_view.to_proto(), + project=project, + start_date=start_date_timestamp, + end_date=end_date_timestamp, + commit=commit, + ) + + self.stub.ApplyMaterialization(request) def apply_saved_dataset( self, @@ -230,10 +292,18 @@ def apply_saved_dataset( project: str, commit: bool = True, ): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.ApplySavedDatasetRequest( + saved_dataset=saved_dataset.to_proto(), project=project, commit=commit + ) - def delete_saved_dataset(self, name: str, project: str, allow_cache: bool = False): - raise ReadOnlyRegistryException() + self.stub.ApplyFeatureService(request) + + def delete_saved_dataset(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteSavedDatasetRequest( + name=name, project=project, commit=commit + ) + + self.stub.DeleteSavedDataset(request) def get_saved_dataset( self, name: str, project: str, allow_cache: bool = False @@ -266,10 +336,20 @@ def apply_validation_reference( project: str, commit: bool = True, ): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.ApplyValidationReferenceRequest( + validation_reference=validation_reference.to_proto(), + project=project, + commit=commit, + ) + + self.stub.ApplyValidationReference(request) def delete_validation_reference(self, name: str, project: str, commit: bool = True): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.DeleteValidationReferenceRequest( + name=name, project=project, commit=commit + ) + + self.stub.DeleteValidationReference(request) def get_validation_reference( self, name: str, project: str, allow_cache: bool = False @@ -308,7 +388,11 @@ def list_project_metadata( return [ProjectMetadata.from_proto(pm) for pm in response.project_metadata] def update_infra(self, infra: Infra, project: str, commit: bool = True): - raise ReadOnlyRegistryException() + request = RegistryServer_pb2.UpdateInfraRequest( + infra=infra.to_proto(), project=project, commit=commit + ) + + self.stub.UpdateInfra(request) def get_infra(self, project: str, allow_cache: bool = False) -> Infra: request = RegistryServer_pb2.GetInfraRequest( @@ -336,9 +420,12 @@ def proto(self) -> RegistryProto: return self.stub.Proto(Empty()) def commit(self): - raise ReadOnlyRegistryException() + self.stub.Commit(Empty()) def refresh(self, project: Optional[str] = None): request = RegistryServer_pb2.RefreshRequest(project=str(project)) self.stub.Refresh(request) + + def teardown(self): + pass diff --git a/sdk/python/feast/registry_server.py b/sdk/python/feast/registry_server.py index 7de0cc43e14..85038ad6ff3 100644 --- a/sdk/python/feast/registry_server.py +++ b/sdk/python/feast/registry_server.py @@ -1,16 +1,34 @@ from concurrent import futures +from datetime import datetime import grpc from google.protobuf.empty_pb2 import Empty from feast import FeatureStore +from feast.data_source import DataSource +from feast.entity import Entity +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra +from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView from feast.protos.feast.registry import RegistryServer_pb2, RegistryServer_pb2_grpc +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView class RegistryServer(RegistryServer_pb2_grpc.RegistryServerServicer): - def __init__(self, store: FeatureStore) -> None: + def __init__(self, registry: BaseRegistry) -> None: super().__init__() - self.proxied_registry = store.registry + self.proxied_registry = registry + + def ApplyEntity(self, request: RegistryServer_pb2.ApplyEntityRequest, context): + self.proxied_registry.apply_entity( + entity=Entity.from_proto(request.entity), + project=request.project, + commit=request.commit, + ) + return Empty() def GetEntity(self, request: RegistryServer_pb2.GetEntityRequest, context): return self.proxied_registry.get_entity( @@ -27,6 +45,22 @@ def ListEntities(self, request, context): ] ) + def DeleteEntity(self, request: RegistryServer_pb2.DeleteEntityRequest, context): + self.proxied_registry.delete_entity( + name=request.name, project=request.project, commit=request.commit + ) + return Empty() + + def ApplyDataSource( + self, request: RegistryServer_pb2.ApplyDataSourceRequest, context + ): + self.proxied_registry.apply_data_source( + data_source=DataSource.from_proto(request.data_source), + project=request.project, + commit=request.commit, + ) + return Empty() + def GetDataSource(self, request: RegistryServer_pb2.GetDataSourceRequest, context): return self.proxied_registry.get_data_source( name=request.name, project=request.project, allow_cache=request.allow_cache @@ -42,6 +76,14 @@ def ListDataSources(self, request, context): ] ) + def DeleteDataSource( + self, request: RegistryServer_pb2.DeleteDataSourceRequest, context + ): + self.proxied_registry.delete_data_source( + name=request.name, project=request.project, commit=request.commit + ) + return Empty() + def GetFeatureView( self, request: RegistryServer_pb2.GetFeatureViewRequest, context ): @@ -49,6 +91,24 @@ def GetFeatureView( name=request.name, project=request.project, allow_cache=request.allow_cache ).to_proto() + def ApplyFeatureView( + self, request: RegistryServer_pb2.ApplyFeatureViewRequest, context + ): + feature_view_type = request.WhichOneof("base_feature_view") + if feature_view_type == "feature_view": + feature_view = FeatureView.from_proto(request.feature_view) + elif feature_view_type == "on_demand_feature_view": + feature_view = OnDemandFeatureView.from_proto( + request.on_demand_feature_view + ) + elif feature_view_type == "stream_feature_view": + feature_view = StreamFeatureView.from_proto(request.stream_feature_view) + + self.proxied_registry.apply_feature_view( + feature_view=feature_view, project=request.project, commit=request.commit + ) + return Empty() + def ListFeatureViews(self, request, context): return RegistryServer_pb2.ListFeatureViewsResponse( feature_views=[ @@ -59,6 +119,14 @@ def ListFeatureViews(self, request, context): ] ) + def DeleteFeatureView( + self, request: RegistryServer_pb2.DeleteFeatureViewRequest, context + ): + self.proxied_registry.delete_feature_view( + name=request.name, project=request.project, commit=request.commit + ) + return Empty() + def GetStreamFeatureView( self, request: RegistryServer_pb2.GetStreamFeatureViewRequest, context ): @@ -93,6 +161,16 @@ def ListOnDemandFeatureViews(self, request, context): ] ) + def ApplyFeatureService( + self, request: RegistryServer_pb2.ApplyFeatureServiceRequest, context + ): + self.proxied_registry.apply_feature_service( + feature_service=FeatureService.from_proto(request.feature_service), + project=request.project, + commit=request.commit, + ) + return Empty() + def GetFeatureService( self, request: RegistryServer_pb2.GetFeatureServiceRequest, context ): @@ -112,6 +190,24 @@ def ListFeatureServices( ] ) + def DeleteFeatureService( + self, request: RegistryServer_pb2.DeleteFeatureServiceRequest, context + ): + self.proxied_registry.delete_feature_service( + name=request.name, project=request.project, commit=request.commit + ) + return Empty() + + def ApplySavedDataset( + self, request: RegistryServer_pb2.ApplySavedDatasetRequest, context + ): + self.proxied_registry.apply_saved_dataset( + saved_dataset=SavedDataset.from_proto(request.saved_dataset), + project=request.project, + commit=request.commit, + ) + return Empty() + def GetSavedDataset( self, request: RegistryServer_pb2.GetSavedDatasetRequest, context ): @@ -131,6 +227,26 @@ def ListSavedDatasets( ] ) + def DeleteSavedDataset( + self, request: RegistryServer_pb2.DeleteSavedDatasetRequest, context + ): + self.proxied_registry.delete_saved_dataset( + name=request.name, project=request.project, commit=request.commit + ) + return Empty() + + def ApplyValidationReference( + self, request: RegistryServer_pb2.ApplyValidationReferenceRequest, context + ): + self.proxied_registry.apply_validation_reference( + validation_reference=ValidationReference.from_proto( + request.validation_reference + ), + project=request.project, + commit=request.commit, + ) + return Empty() + def GetValidationReference( self, request: RegistryServer_pb2.GetValidationReferenceRequest, context ): @@ -150,6 +266,14 @@ def ListValidationReferences( ] ) + def DeleteValidationReference( + self, request: RegistryServer_pb2.DeleteValidationReferenceRequest, context + ): + self.proxied_registry.delete_validation_reference( + name=request.name, project=request.project, commit=request.commit + ) + return Empty() + def ListProjectMetadata( self, request: RegistryServer_pb2.ListProjectMetadataRequest, context ): @@ -162,11 +286,39 @@ def ListProjectMetadata( ] ) + def ApplyMaterialization( + self, request: RegistryServer_pb2.ApplyMaterializationRequest, context + ): + self.proxied_registry.apply_materialization( + feature_view=FeatureView.from_proto(request.feature_view), + project=request.project, + start_date=datetime.fromtimestamp( + request.start_date.seconds + request.start_date.nanos / 1e9 + ), + end_date=datetime.fromtimestamp( + request.end_date.seconds + request.end_date.nanos / 1e9 + ), + commit=request.commit, + ) + return Empty() + + def UpdateInfra(self, request: RegistryServer_pb2.UpdateInfraRequest, context): + self.proxied_registry.update_infra( + infra=Infra.from_proto(request.infra), + project=request.project, + commit=request.commit, + ) + return Empty() + def GetInfra(self, request: RegistryServer_pb2.GetInfraRequest, context): return self.proxied_registry.get_infra( project=request.project, allow_cache=request.allow_cache ).to_proto() + def Commit(self, request, context): + self.proxied_registry.commit() + return Empty() + def Refresh(self, request, context): self.proxied_registry.refresh(request.project) return Empty() @@ -178,7 +330,7 @@ def Proto(self, request, context): def start_server(store: FeatureStore, port: int): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) RegistryServer_pb2_grpc.add_RegistryServerServicer_to_server( - RegistryServer(store), server + RegistryServer(store.registry), server ) server.add_insecure_port(f"[::]:{port}") server.start() diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py index 1f0ccb4f6b5..18274ae5e19 100644 --- a/sdk/python/tests/integration/registration/test_universal_registry.py +++ b/sdk/python/tests/integration/registration/test_universal_registry.py @@ -18,6 +18,7 @@ from tempfile import mkstemp from unittest import mock +import grpc_testing import pandas as pd import pytest from pytest_lazyfixture import lazy_fixture @@ -36,8 +37,11 @@ from feast.infra.infra_object import Infra from feast.infra.online_stores.sqlite import SqliteTable from feast.infra.registry.registry import Registry +from feast.infra.registry.remote import RemoteRegistry, RemoteRegistryConfig from feast.infra.registry.sql import SqlRegistry from feast.on_demand_feature_view import on_demand_feature_view +from feast.protos.feast.registry import RegistryServer_pb2, RegistryServer_pb2_grpc +from feast.registry_server import RegistryServer from feast.repo_config import RegistryConfig from feast.stream_feature_view import Aggregation, StreamFeatureView from feast.types import Array, Bytes, Float32, Int32, Int64, String @@ -187,19 +191,70 @@ def sqlite_registry(): yield SqlRegistry(registry_config, "project", None) +class GrpcMockChannel: + def __init__(self, service, servicer): + self.service = service + self.test_server = grpc_testing.server_from_dictionary( + {service: servicer}, + grpc_testing.strict_real_time(), + ) + + def unary_unary( + self, method: str, request_serializer=None, response_deserializer=None + ): + method_name = method.split("/")[-1] + method_descriptor = self.service.methods_by_name[method_name] + + def handler(request): + rpc = self.test_server.invoke_unary_unary( + method_descriptor, (), request, None + ) + + response, trailing_metadata, code, details = rpc.termination() + return response + + return handler + + +@pytest.fixture +def mock_remote_registry(): + fd, registry_path = mkstemp() + registry_config = RegistryConfig(path=registry_path, cache_ttl_seconds=600) + proxied_registry = Registry("project", registry_config, None) + + registry = RemoteRegistry( + registry_config=RemoteRegistryConfig(path=""), project=None, repo_path=None + ) + mock_channel = GrpcMockChannel( + RegistryServer_pb2.DESCRIPTOR.services_by_name["RegistryServer"], + RegistryServer(registry=proxied_registry), + ) + registry.stub = RegistryServer_pb2_grpc.RegistryServerStub(mock_channel) + yield registry + + +if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "False": + all_fixtures = ["s3_registry", "gcs_registry"] +else: + all_fixtures = [ + "local_registry", + "minio_registry", + "pg_registry", + "mysql_registry", + "sqlite_registry", + "mock_remote_registry", + ] + + +# sql_fixtures = [ +# "pg_registry", +# "mysql_registry", +# "sqlite_registry", +# ] + + @pytest.mark.integration -@pytest.mark.parametrize( - "test_registry", - [ - lazy_fixture("local_registry"), - lazy_fixture("gcs_registry"), - lazy_fixture("s3_registry"), - lazy_fixture("minio_registry"), - lazy_fixture("pg_registry"), - lazy_fixture("mysql_registry"), - lazy_fixture("sqlite_registry"), - ], -) +@pytest.mark.parametrize("test_registry", [lazy_fixture(f) for f in all_fixtures]) def test_apply_entity_success(test_registry): entity = Entity( name="driver_car_id", @@ -258,15 +313,7 @@ def assert_project_uuid(project, project_uuid, test_registry): @pytest.mark.integration @pytest.mark.parametrize( "test_registry", - [ - lazy_fixture("local_registry"), - lazy_fixture("gcs_registry"), - lazy_fixture("s3_registry"), - lazy_fixture("minio_registry"), - lazy_fixture("pg_registry"), - lazy_fixture("mysql_registry"), - lazy_fixture("sqlite_registry"), - ], + [lazy_fixture(f) for f in all_fixtures], ) def test_apply_feature_view_success(test_registry): # Create Feature Views @@ -361,6 +408,7 @@ def test_apply_feature_view_success(test_registry): lazy_fixture("pg_registry"), lazy_fixture("mysql_registry"), lazy_fixture("sqlite_registry"), + # lazy_fixture("mock_remote_registry"), ], ) def test_apply_on_demand_feature_view_success(test_registry): @@ -443,15 +491,7 @@ def location_features_from_push(inputs: pd.DataFrame) -> pd.DataFrame: @pytest.mark.integration @pytest.mark.parametrize( "test_registry", - [ - lazy_fixture("local_registry"), - lazy_fixture("gcs_registry"), - lazy_fixture("s3_registry"), - lazy_fixture("minio_registry"), - lazy_fixture("pg_registry"), - lazy_fixture("mysql_registry"), - lazy_fixture("sqlite_registry"), - ], + [lazy_fixture(f) for f in all_fixtures], ) def test_apply_data_source(test_registry): # Create Feature Views @@ -514,15 +554,7 @@ def test_apply_data_source(test_registry): @pytest.mark.integration @pytest.mark.parametrize( "test_registry", - [ - lazy_fixture("local_registry"), - lazy_fixture("gcs_registry"), - lazy_fixture("s3_registry"), - lazy_fixture("minio_registry"), - lazy_fixture("pg_registry"), - lazy_fixture("mysql_registry"), - lazy_fixture("sqlite_registry"), - ], + [lazy_fixture(f) for f in all_fixtures], ) def test_modify_feature_views_success(test_registry): # Create Feature Views @@ -653,6 +685,7 @@ def odfv1(feature_df: pd.DataFrame) -> pd.DataFrame: lazy_fixture("pg_registry"), lazy_fixture("mysql_registry"), lazy_fixture("sqlite_registry"), + # lazy_fixture("mock_remote_registry"), ], ) def test_update_infra(test_registry): @@ -692,6 +725,7 @@ def test_update_infra(test_registry): lazy_fixture("pg_registry"), lazy_fixture("mysql_registry"), lazy_fixture("sqlite_registry"), + # lazy_fixture("mock_remote_registry"), ], ) def test_registry_cache(test_registry): @@ -756,15 +790,7 @@ def test_registry_cache(test_registry): @pytest.mark.integration @pytest.mark.parametrize( "test_registry", - [ - lazy_fixture("local_registry"), - lazy_fixture("gcs_registry"), - lazy_fixture("s3_registry"), - lazy_fixture("minio_registry"), - lazy_fixture("pg_registry"), - lazy_fixture("mysql_registry"), - lazy_fixture("sqlite_registry"), - ], + [lazy_fixture(f) for f in all_fixtures], ) def test_apply_stream_feature_view_success(test_registry): # Create Feature Views From c9f3d8404e715f815697ae02462718cf62b90346 Mon Sep 17 00:00:00 2001 From: tokoko Date: Mon, 27 May 2024 14:55:34 +0000 Subject: [PATCH 2/2] fix remote registry tests Signed-off-by: tokoko --- sdk/python/tests/unit/infra/registry/test_remote.py | 2 +- sdk/python/tests/unit/test_registry_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/python/tests/unit/infra/registry/test_remote.py b/sdk/python/tests/unit/infra/registry/test_remote.py index 16c6f0abfb0..8b15f0d507f 100644 --- a/sdk/python/tests/unit/infra/registry/test_remote.py +++ b/sdk/python/tests/unit/infra/registry/test_remote.py @@ -41,7 +41,7 @@ def mock_remote_registry(environment): ) mock_channel = GrpcMockChannel( RegistryServer_pb2.DESCRIPTOR.services_by_name["RegistryServer"], - RegistryServer(store=store), + RegistryServer(registry=store._registry), ) registry.stub = RegistryServer_pb2_grpc.RegistryServerStub(mock_channel) return registry diff --git a/sdk/python/tests/unit/test_registry_server.py b/sdk/python/tests/unit/test_registry_server.py index 734bbfe19b8..462983d8983 100644 --- a/sdk/python/tests/unit/test_registry_server.py +++ b/sdk/python/tests/unit/test_registry_server.py @@ -21,7 +21,7 @@ def call_registry_server(server, method: str, request=None): def registry_server(environment): store: FeatureStore = environment.feature_store - servicer = RegistryServer(store=store) + servicer = RegistryServer(registry=store._registry) return grpc_testing.server_from_dictionary( {RegistryServer_pb2.DESCRIPTOR.services_by_name["RegistryServer"]: servicer},