Skip to content

Commit

Permalink
chore(model): refactor model ppl to handler (#1279)
Browse files Browse the repository at this point in the history
refactor model ppl to handler
  • Loading branch information
tianweidut authored Sep 22, 2022
1 parent e031614 commit c2ebecd
Show file tree
Hide file tree
Showing 16 changed files with 18 additions and 21 deletions.
2 changes: 1 addition & 1 deletion client/starwhale/core/model/default_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def _get_cls(src_dir: Path) -> Any:
_mp = src_dir / DefaultYAMLName.MODEL
_model_config = StandaloneModel.load_model_config(_mp)
_handler = _model_config.run.ppl
_handler = _model_config.run.handler

logger.debug(f"try to import {_handler}@{src_dir}...")
_cls = import_cls(src_dir, _handler, PipelineHandler)
Expand Down
18 changes: 9 additions & 9 deletions client/starwhale/core/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ class ModelRunConfig(ASDictMixin):
# TODO: use attr to tune class
def __init__(
self,
ppl: str,
handler: str,
type: str = EvalHandlerType.DEFAULT,
runtime: str = "",
pkg_data: t.Union[t.List[str], None] = None,
exclude_pkg_data: t.Union[t.List[str], None] = None,
envs: t.Union[t.List[str], None] = None,
**kw: t.Any,
):
self.ppl = ppl.strip()
self.handler = handler.strip()
self.typ = type
self.runtime = runtime.strip()
self.pkg_data = pkg_data or []
Expand All @@ -65,14 +65,14 @@ def __init__(
self._do_validate()

def _do_validate(self) -> None:
if not self.ppl:
if not self.handler:
raise FileFormatError("need ppl field")

def __str__(self) -> str:
return f"Model Run Config: ppl -> {self.ppl}"
return f"Model Run Config: ppl -> {self.handler}"

def __repr__(self) -> str:
return f"Model Run Config: ppl -> {self.ppl}, runtime -> {self.runtime}"
return f"Model Run Config: ppl -> {self.handler}, runtime -> {self.runtime}"


class ModelConfig(ASDictMixin):
Expand Down Expand Up @@ -192,7 +192,7 @@ def get_pipeline_handler(
_model_config = cls.load_model_config(_mp)
if _model_config.run.typ == EvalHandlerType.DEFAULT:
return DEFAULT_EVALUATION_PIPELINE
return _model_config.run.ppl
return _model_config.run.handler

@classmethod
def eval_user_handler(
Expand Down Expand Up @@ -227,7 +227,7 @@ def eval_user_handler(
if _model_config.run.typ == EvalHandlerType.DEFAULT:
_module = DEFAULT_EVALUATION_PIPELINE
else:
_module = _model_config.run.ppl
_module = _model_config.run.handler

_yaml_path = str(workdir / DEFAULT_EVALUATION_JOBS_FNAME)

Expand All @@ -236,7 +236,7 @@ def eval_user_handler(
if _model_config.run.typ == EvalHandlerType.DEFAULT:
_ppl = DEFAULT_EVALUATION_PIPELINE
else:
_ppl = _model_config.run.ppl
_ppl = _model_config.run.handler

_new_yaml_path = _run_dir / DEFAULT_EVALUATION_JOBS_FNAME
Parser.generate_job_yaml(_ppl, workdir, _new_yaml_path)
Expand Down Expand Up @@ -404,7 +404,7 @@ def buildImpl(
self._gen_steps,
5,
"generate execute steps",
dict(typ=_model_config.run.typ, ppl=_model_config.run.ppl),
dict(typ=_model_config.run.typ, ppl=_model_config.run.handler),
),
(
self._render_manifest,
Expand Down
2 changes: 1 addition & 1 deletion client/tests/data/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ config:
- config/hyperparam.json

run:
ppl: mnist.ppl:MNISTInference
handler: mnist.evaluator:MNISTInference
exclude_pkg_data:
- venv
- .git
Expand Down
2 changes: 1 addition & 1 deletion client/tests/data/model/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ config:
- config/hyperparam.json

run:
ppl: mnist.ppl:MNISTInference
handler: mnist.evaluator:MNISTInference
exclude_pkg_data:
- venv
- .git
Expand Down
2 changes: 1 addition & 1 deletion example/PennFudanPed/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ name: mask_rcnn
model:
- models/mcrnn.pth
run:
ppl: pfp.ppl:MaskRCnn
handler: pfp.evaluator:MaskRCnn
desc: mask rcnn resnet50 by pytorch
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion example/cifar10/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ model:
- models/cifar_net.pth

run:
ppl: cifar.ppl:CIFAR10Inference
handler: cifar.evaluator:CIFAR10Inference

desc: cifar10 by pytorch
File renamed without changes.
5 changes: 1 addition & 4 deletions example/mnist/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ config:
- config/hyperparam.json

run:
ppl: mnist.ppl:MNISTInference
handler: mnist.evaluator:MNISTInference

desc: mnist by pytorch

tag:
- multi_classification
2 changes: 1 addition & 1 deletion example/nmt/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ model:
- models/vocab_eng-fra.bin

run:
ppl: nmt.ppl:NMTPipeline
handler: nmt.evaluator:NMTPipeline

desc: nmt by pytorch
File renamed without changes.
2 changes: 1 addition & 1 deletion example/speech_command/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ name: speech_commands_m5
model:
- models/m5.pth
run:
ppl: sc.ppl:M5Inference
handler: sc.evaluator:M5Inference

desc: m5 by pytorch
File renamed without changes.
2 changes: 1 addition & 1 deletion example/text_cls_AG_NEWS/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ model:
- models/vocab.i

run:
ppl: tcan.ppl:TextClassificationHandler
handler: tcan.evaluator:TextClassificationHandler

desc: TextClassification by pytorch
File renamed without changes.

0 comments on commit c2ebecd

Please sign in to comment.