Generated cases for turn network
sebastiannberg committed Apr 25, 2024
1 parent e8562d4 commit 3f467d0
Showing 7 changed files with 1,825 additions and 20 deletions.
@@ -10,8 +10,9 @@

class PokerDataset(Dataset):

def __init__(self, csv_filename: str, mode: str, param_file: str = None):
def __init__(self, stage: str, csv_filename: str, mode: str, param_file: str = None):
super().__init__()
self.stage = stage
self.current_dir = os.path.dirname(__file__)
self.dataset_path = os.path.join(self.current_dir, "raw", csv_filename)
self.mode = mode
@@ -91,7 +92,7 @@ def calculate_normalization_params(self):
}

# Save parameters to a file
param_file = os.path.join(self.current_dir, "normalization_params.json")
param_file = os.path.join(self.current_dir, f"{self.stage}_normalization_params.json")
with open(param_file, 'w') as f:
json.dump(self.params, f)

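For context, a minimal sketch of how the stage-aware dataset is now constructed and where its normalization statistics end up; the import path is an assumption, and the CSV name is copied from train_river.py below.

# Hedged sketch, not part of the commit: the module path is assumed.
from data.poker_dataset import PokerDataset  # assumed import path

# The constructor now takes a stage argument in addition to the CSV and mode.
train_ds = PokerDataset(
    stage="river",
    csv_filename="river_cases_2024-04-22_19-45-01.csv",
    mode="train",
)

# calculate_normalization_params() now writes per-stage statistics to
# river_normalization_params.json next to the dataset module, instead of a
# single shared normalization_params.json, so river and turn statistics
# presumably no longer overwrite each other.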

Large diffs are not rendered by default.

Empty file.
24 changes: 12 additions & 12 deletions deepstack_knock_off/src/resolver/neural_network/predictor.py
@@ -12,11 +12,19 @@ class Predictor:
def __init__(self):
self.saved_models_path = os.path.join(os.path.dirname(__file__), "saved_models")
self.river_model = "24-04-2024_19-53-45_epoch_125.pt"
self.normalization_params = os.path.join(os.path.dirname(__file__), "data", "normalization_params.json")
with open(self.normalization_params, 'r') as f:
self.params = json.load(f)

def make_prediction(self, stage, r1, r2, public_cards, pot):
if stage == "river":
model_path = os.path.join(self.saved_models_path, "river", self.river_model)
model = RiverNetwork()
model.load_state_dict(torch.load(model_path))
model.eval()
normalization_params = os.path.join(os.path.dirname(__file__), "data", f"{stage}_normalization_params.json")
with open(normalization_params, 'r') as f:
params = json.load(f)
else:
raise ValueError(f"{stage} not recognized.")

# One hot encoding for public cards
public_cards_one_hot = np.zeros(24, dtype=np.int8)
for card in public_cards:
@@ -25,22 +25,14 @@ def make_prediction(self, stage, r1, r2, public_cards, pot):

# Normalize pot
pot = np.array([pot], dtype=np.float64)
pot = (pot - self.params["pot_mean"]) / self.params["pot_std"]
pot = (pot - params["pot_mean"]) / params["pot_std"]

# Create tensors for r1, r2, public cards, pot
r1_tensor = torch.from_numpy(r1).float()
r2_tensor = torch.from_numpy(r2).float()
public_cards_tensor = torch.from_numpy(public_cards_one_hot).float().unsqueeze(0)
pot_tensor = torch.from_numpy(pot).float().unsqueeze(0)

if stage == "river":
model_path = os.path.join(self.saved_models_path, "river", self.river_model)
model = RiverNetwork()
model.load_state_dict(torch.load(model_path))
model.eval()
else:
raise ValueError("{stage} not recognized.")

with torch.no_grad():
predicted_v1, predicted_v2, _ = model(r1_tensor, r2_tensor, public_cards_tensor, pot_tensor)

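A hedged usage sketch of the reworked make_prediction, which now resolves the model checkpoint and the stage-specific normalization parameters inside the method from the stage argument; the import path, the range-vector size, and the card format are assumptions, and the return value is assumed to be the two predicted value vectors.

# Hedged sketch, not part of the commit.
import numpy as np
from resolver.neural_network.predictor import Predictor  # assumed import path

predictor = Predictor()

# r1/r2 are the players' range vectors; a 2D (1, n_hands) shape is assumed
# because make_prediction turns them into tensors without adding a batch dim.
n_hands = 276  # assumed size, e.g. C(24, 2) hole-card pairs for a 24-card deck
r1 = np.full((1, n_hands), 1.0 / n_hands)
r2 = np.full((1, n_hands), 1.0 / n_hands)

public_cards = []  # placeholder; card objects in whatever format the one-hot encoder expects
pot = 40

# Model weights and river_normalization_params.json are now loaded here,
# keyed on the stage argument, rather than once in __init__.
v1, v2 = predictor.make_prediction("river", r1, r2, public_cards, pot)

One consequence of this change is that the checkpoint and JSON file are re-read on every call; presumably that is acceptable for how the resolver uses the predictor.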
10 changes: 5 additions & 5 deletions deepstack_knock_off/src/resolver/neural_network/train_river.py
@@ -22,12 +22,12 @@ def train():
WEIGHT_DECAY = 0.00001
BATCH_SIZE = 16
CSV_FILENAME = "river_cases_2024-04-22_19-45-01.csv"
NORMALIZATION_PARAMS_PATH = os.path.join(os.path.dirname(__file__), "data", "normalization_params.json")
NORMALIZATION_PARAMS_PATH = os.path.join(os.path.dirname(__file__), "data", "river_normalization_params.json")
SAVED_MODELS_PATH = os.path.join(os.path.dirname(__file__), "saved_models")

print("Creating datasets")
train_dataset = PokerDataset(csv_filename=CSV_FILENAME, mode="train", param_file=NORMALIZATION_PARAMS_PATH)
validation_dataset = PokerDataset(csv_filename=CSV_FILENAME, mode="validation", param_file=NORMALIZATION_PARAMS_PATH)
train_dataset = PokerDataset(stage="river", csv_filename=CSV_FILENAME, mode="train", param_file=NORMALIZATION_PARAMS_PATH)
validation_dataset = PokerDataset(stage="river", csv_filename=CSV_FILENAME, mode="validation", param_file=NORMALIZATION_PARAMS_PATH)
print(f"Train dataset created with {len(train_dataset)} samples")
print(f"Validation dataset created with {len(validation_dataset)} samples")

@@ -100,11 +100,11 @@ def test(model_filename: str):
print("\033[1;32m" + "="*15 + " Testing " + "="*15 + "\033[0m")
LOSS_FUNCTION = nn.MSELoss()
CSV_FILENAME = "river_cases_2024-04-22_19-45-01.csv"
NORMALIZATION_PARAMS_PATH = os.path.join(os.path.dirname(__file__), "data", "normalization_params.json")
NORMALIZATION_PARAMS_PATH = os.path.join(os.path.dirname(__file__), "data", "river_normalization_params.json")
SAVED_MODELS_PATH = os.path.join(os.path.dirname(__file__), "saved_models")

print("Creating dataset")
test_dataset = PokerDataset(csv_filename=CSV_FILENAME, mode="test", param_file=NORMALIZATION_PARAMS_PATH)
test_dataset = PokerDataset(stage="river", csv_filename=CSV_FILENAME, mode="test", param_file=NORMALIZATION_PARAMS_PATH)
print(f"Test dataset created with {len(test_dataset)} samples")

print("Creating data loaders")
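A hedged sketch of the updated entry points; whether test() takes just a checkpoint basename is an assumption, with the filename borrowed from predictor.py above.

# Hedged sketch, not part of the commit: module path and calling convention are assumed.
from resolver.neural_network import train_river  # assumed import path

train_river.train()  # presumably also (re)writes data/river_normalization_params.json via the train-mode dataset
train_river.test(model_filename="24-04-2024_19-53-45_epoch_125.pt")  # assumed basename convention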
@@ -218,4 +218,4 @@ def generate_turn_cases(num_cases: int):
duration_minutes = duration / 60
print(f"Generating data took {duration_minutes:.2f} minutes to run")

generate_turn_cases(num_cases=100)
generate_turn_cases(num_cases=5000)
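Presumably the 5,000 generated turn cases are meant to feed a stage="turn" dataset and a turn network analogous to the river pipeline above; a hedged sketch with a placeholder CSV name, since the output filename produced by generate_turn_cases is not shown in this commit.

# Hedged sketch, not part of the commit: import path and filename are placeholders.
from data.poker_dataset import PokerDataset  # assumed import path

turn_train = PokerDataset(
    stage="turn",
    csv_filename="turn_cases_<timestamp>.csv",  # placeholder; actual name not shown here
    mode="train",
)
# With stage="turn", the normalization statistics would be written to
# turn_normalization_params.json, kept separate from the river file.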
