[Q] How to properly save and load fp8 NumPy arrays? #207

Open
apivovarov opened this issue Sep 25, 2024 · 3 comments

apivovarov (Contributor) commented Sep 25, 2024

I would like to save and load a float8_e5m2 array. I initially tried the standard numpy.save() and numpy.load() functions, but loading fails:

.local/lib/python3.10/site-packages/numpy/lib/format.py", line 325, in descr_to_dtype
    return numpy.dtype(descr)
TypeError: data type '<f1' not understood

.local/lib/python3.10/site-packages/numpy/lib/format.py", line 683, in _read_array_header
    raise ValueError(msg.format(d['descr'])) from e
ValueError: descr is not a valid dtype descriptor: '<f1'
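
For reference, a minimal reproduction of the failure looks roughly like this (a sketch, assuming ml_dtypes is installed):

import ml_dtypes
import numpy as np

x = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)

np.save("x.npy", x)
np.load("x.npy")  # fails with the "descr is not a valid dtype descriptor: '<f1'" error above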

I found that I can save and load float8 arrays using a lower-level API (ndarray.tobytes / np.frombuffer), as shown below:

import ml_dtypes
import numpy as np
import json

# Create the array
x = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)

# Save the array
with open("a.npy", "wb") as f:
    f.write(x.tobytes())

# Save the array's shape and dtype separately
meta = {"shape": x.shape, "dtype": str(x.dtype)}
with open("a_meta.json", "w") as f:
    json.dump(meta, f)

# Load the array
with open("a.npy", "rb") as f:
    data = f.read()

# Load the metadata
with open("a_meta.json", "r") as f:
    meta = json.load(f)

# Reconstruct the array
x2 = np.frombuffer(data, dtype=getattr(ml_dtypes, meta["dtype"])).reshape(meta["shape"])

print(x2)

Is the solution above (ndarray.tobytes / np.frombuffer) considered best practice for this case?

@jakevdp Jake, can you comment on it?

@apivovarov apivovarov changed the title [Q] How to properly save and load fp8 numpy arrays? [Q] How to properly save and load fp8 NumPy arrays? Sep 25, 2024
jakevdp (Collaborator) commented Sep 26, 2024

Unfortunately NumPy's array serialization only works with NumPy's built-in dtypes. Probably the easiest way to serialize arrays with custom dtypes is to view them as unsigned int:

import ml_dtypes
import numpy as np
import json

# Create the array
x = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)

np.save('x.npy', x.view('uint8'))
x2 = np.load('x.npy').view(ml_dtypes.float8_e5m2)

print(np.all(x == x2))
# True
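
The same trick should extend to the wider types in ml_dtypes by viewing as an unsigned integer of matching width; for example, bfloat16 (2 bytes per element) would be viewed as uint16. A small sketch, not from the original exchange:

import ml_dtypes
import numpy as np

y = np.array([.2, .4, .6], dtype=ml_dtypes.bfloat16)

# bfloat16 is 2 bytes wide, so round-trip through uint16 instead of uint8
np.save('y.npy', y.view('uint16'))
y2 = np.load('y.npy').view(ml_dtypes.bfloat16)

print(np.all(y == y2))
# True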

Your approach of serializing the raw bytes also works, though I'd recommend not using the .npy extension in that case, because that extension typically implies the file is loadable with np.load.

apivovarov (Contributor, Author) commented

Hi Jake, thank you for your reply!

I have one additional question.

I tried using pickle. It works, and the file size is almost the same as with the default np.save approach:

import ml_dtypes
import numpy as np
import pickle

# Create the array
a = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)
b = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e4m3)

# Save
with open('a.npy.pkl', "wb") as f:
  pickle.dump(a, f)

with open('b.npy.pkl', "wb") as f:
  pickle.dump(b, f)

# Load back
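# np.load falls back to pickle.load for data that is not in .npy format when allow_pickle=True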
a2 = np.load('a.npy.pkl', allow_pickle=True)
b2 = np.load('b.npy.pkl', allow_pickle=True)

print(np.all(a == a2))
print(np.all(b == b2))

It seems to work out of the box and stores the ml_dtypes dtype info in the file.

What are the disadvantages of using pickle?

Cons I found so far:

  • Pickle is not "secure".
  • It is sensitive to the NumPy version: an environment with an older NumPy might not be able to open files saved in an environment with a newer NumPy, e.g.
>>> b2 = np.load('b.npy.pkl', allow_pickle=True)
Traceback (most recent call last):
  File "/home/user/.local/lib/python3.9/site-packages/numpy/lib/npyio.py", line 441, in load
    return pickle.load(fid, **pickle_kwargs)
ModuleNotFoundError: No module named 'numpy._core'

jakevdp (Collaborator) commented Oct 1, 2024

Yes, pickle works, but it has downsides. The two you mention are the main issues: unpickling allows arbitrary code execution, and pickled files will often break when loaded in an environment with different package versions.
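
For completeness, one way to get a pickle-free round trip is to combine the uint8 view with the dtype name stored alongside it in an .npz file. This is only a sketch; the save_fp8 / load_fp8 helpers are made-up names, not part of NumPy or ml_dtypes:

import ml_dtypes
import numpy as np

def save_fp8(path, arr):
    # Store the raw bits as uint8 plus the dtype name; np.savez needs no pickle for either.
    np.savez(path, bits=arr.view('uint8'), dtype=str(arr.dtype))

def load_fp8(path):
    with np.load(path) as data:
        dtype = getattr(ml_dtypes, str(data['dtype']))
        return data['bits'].view(dtype)

a = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)
save_fp8('a_fp8.npz', a)
print(np.all(a == load_fp8('a_fp8.npz')))
# True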
