Skip to content

Commit

Permalink
Add conditional CUDA support with env variable (#189)
Browse files Browse the repository at this point in the history
* initial GPU support
* Update langkit/transformer.py
* Update langkit/toxicity.py
  • Loading branch information
richard-rogers authored Nov 17, 2023
1 parent 1dffcd4 commit 95b1fb3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
9 changes: 8 additions & 1 deletion langkit/toxicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
from whylogs.experimental.core.udf_schema import register_dataset_udf
from langkit import LangKitConfig, lang_config, prompt_column, response_column

import os
import torch

_USE_CUDA = torch.cuda.is_available() and not bool(
os.environ.get("LANGKIT_NO_CUDA", False)
)
_device = 0 if _USE_CUDA else -1

_prompt = prompt_column
_response = response_column
Expand Down Expand Up @@ -46,7 +53,7 @@ def init(model_path: Optional[str] = None, config: Optional[LangKitConfig] = Non
_toxicity_tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
_toxicity_pipeline = TextClassificationPipeline(
model=model, tokenizer=_toxicity_tokenizer
model=model, tokenizer=_toxicity_tokenizer, device=_device
)


Expand Down
11 changes: 10 additions & 1 deletion langkit/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
from torch import Tensor
import numpy as np

import os
import torch

_USE_CUDA = torch.cuda.is_available() and not bool(
os.environ.get("LANGKIT_NO_CUDA", False)
)
_device = "cuda" if _USE_CUDA else "cpu"


try:
import tensorflow as tf
except ImportError:
Expand Down Expand Up @@ -39,7 +48,7 @@ def __init__(
transformer_model = CustomEncoder(custom_encoder)
self.transformer_name = "custom_encoder"
if transformer_name:
transformer_model = SentenceTransformer(transformer_name)
transformer_model = SentenceTransformer(transformer_name, device=_device)
self.transformer_name = transformer_name
self.transformer_model = transformer_model

Expand Down

0 comments on commit 95b1fb3

Please sign in to comment.