Skip to content

Commit

Permalink
refactor!: revert returned waveform memory layout (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
kahojyun authored Apr 26, 2024
1 parent 571026f commit cb3e507
Show file tree
Hide file tree
Showing 13 changed files with 135 additions and 152 deletions.
26 changes: 0 additions & 26 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ enum_dispatch = "0.3.13"
float-cmp = "0.9.0"
hashbrown = { version = "0.14.3", features = ["rayon"] }
itertools = "0.12.1"
mimalloc = "0.1.39"
ndarray = { version = "0.15.6", features = ["rayon"] }
num = "0.4.1"
numpy = "0.21.0"
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ import matplotlib.pyplot as plt

from bosing import Barrier, Channel, Hann, Play, Stack, generate_waveforms

channels = [Channel("xy", 30e6, 2e9, 1000)]
shapes = [Hann()]
channels = {"xy": Channel(30e6, 2e9, 1000)}
shapes = {"hann": Hann()}
schedule = Stack(duration=500e-9).with_children(
Play(
channel_id=0,
shape_id=0,
channel_id="xy",
shape_id="hann",
amplitude=0.3,
width=100e-9,
plateau=200e-9,
Expand All @@ -37,8 +37,8 @@ schedule = Stack(duration=500e-9).with_children(
)
result = generate_waveforms(channels, shapes, schedule)
w = result["xy"]
plt.plot(w.real, label="I")
plt.plot(w.imag, label="Q")
plt.plot(w[0], label="I")
plt.plot(w[1], label="Q")
plt.legend()
plt.show()
```
Expand Down
8 changes: 4 additions & 4 deletions example/flexible.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
)
result = generate_waveforms(channels, shapes, schedule)
w = result["xy"]
plt.plot(w.real, label="xy I")
plt.plot(w.imag, label="xy Q")
plt.plot(w[0], label="xy I")
plt.plot(w[1], label="xy Q")
w = result["u"]
plt.plot(w.real, label="u I")
plt.plot(w.imag, label="u Q")
plt.plot(w[0], label="u I")
plt.plot(w[1], label="u Q")
plt.legend()
4 changes: 2 additions & 2 deletions example/hann.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
)
result = generate_waveforms(channels, shapes, schedule)
w = result["xy"]
plt.plot(w.real, label="I")
plt.plot(w.imag, label="Q")
plt.plot(w[0], label="I")
plt.plot(w[1], label="Q")
plt.legend()
4 changes: 2 additions & 2 deletions example/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@
)
result = generate_waveforms(channels, shapes, schedule)
w = result["xy"]
plt.plot(w.real, label="I")
plt.plot(w.imag, label="Q")
plt.plot(w[0], label="I")
plt.plot(w[1], label="Q")
plt.legend()
4 changes: 2 additions & 2 deletions example/overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@
)
result = generate_waveforms(channels, shapes, schedule)
w = result["m"]
plt.plot(w.real, label="I")
plt.plot(w.imag, label="Q")
plt.plot(w[0], label="I")
plt.plot(w[1], label="Q")
plt.legend()
16 changes: 8 additions & 8 deletions example/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@
result = generate_waveforms(channels, shapes, schedule, time_tolerance=1e-13)

t = np.arange(length) / 2e9
plt.plot(t, result["xy0"].real)
plt.plot(t, result["xy0"].imag)
plt.plot(t, result["xy1"].real)
plt.plot(t, result["xy1"].imag)
plt.plot(t, result["u1"].real)
plt.plot(t, result["u1"].imag)
plt.plot(t, result["m0"].real)
plt.plot(t, result["m0"].imag)
plt.plot(t, result["xy0"][0])
plt.plot(t, result["xy0"][1])
plt.plot(t, result["xy1"][0])
plt.plot(t, result["xy1"][1])
plt.plot(t, result["u1"][0])
plt.plot(t, result["u1"][1])
plt.plot(t, result["m0"][0])
plt.plot(t, result["m0"][1])
plt.show()
55 changes: 29 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
//! Although Element struct may contains [`Py<Element>`] as children, it is not
//! possible to create cyclic references because we don't allow mutate the
//! children after creation.
use hashbrown::HashMap;
use mimalloc::MiMalloc;
use numpy::prelude::*;
use numpy::{AllowTypeChange, Complex64, PyArray1, PyArrayLike2};
use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use std::sync::Arc;

use crate::executor::Executor;
use crate::pulse::Sampler;
use crate::quant::{Frequency, Time};
use schedule::ElementCommonBuilder;
use hashbrown::HashMap;
use numpy::{prelude::*, AllowTypeChange, PyArray2, PyArrayLike2};
use pyo3::{
exceptions::{PyRuntimeError, PyTypeError, PyValueError},
prelude::*,
};

use crate::{
executor::Executor,
pulse::Sampler,
quant::{Frequency, Time},
schedule::ElementCommonBuilder,
};

mod executor;
mod pulse;
mod quant;
mod schedule;
mod shape;

#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;

/// Channel configuration.
///
/// `align_level` is the time axis alignment granularity. With sampling interval
Expand Down Expand Up @@ -1951,7 +1950,8 @@ impl Grid {
/// with corresponding channel ids. Default is ``None``.
/// Returns:
/// Dict[str, numpy.ndarray]: Waveforms of the channels. The key is the
/// channel name and the value is the waveform.
/// channel name and the value is the waveform. The shape of the
/// waveform is ``(2, length)``.
/// Raises:
/// ValueError: If some input is invalid.
/// TypeError: If some input has an invalid type.
Expand Down Expand Up @@ -1994,7 +1994,7 @@ fn generate_waveforms(
amp_tolerance: f64,
allow_oversize: bool,
crosstalk: Option<(PyArrayLike2<'_, f64, AllowTypeChange>, Vec<String>)>,
) -> PyResult<HashMap<String, Py<PyArray1<Complex64>>>> {
) -> PyResult<HashMap<String, Py<PyArray2<f64>>>> {
if let Some((crosstalk, names)) = &crosstalk {
if crosstalk.ndim() != 2 {
return Err(PyValueError::new_err("Crosstalk must be a 2D array."));
Expand All @@ -2008,7 +2008,6 @@ fn generate_waveforms(
));
}
}

let root = schedule.downcast::<Element>()?.get().0.clone();
let measured = schedule::measure(root, f64::INFINITY);
let arrange_options = schedule::ScheduleOptions {
Expand All @@ -2027,13 +2026,17 @@ fn generate_waveforms(
}
executor.execute(&arranged);
let results = executor.into_result();
let mut sampler = Sampler::new();
for (n, pl) in results {
let c = &channels[&n];
let waveforms: HashMap<String, Bound<PyArray2<f64>>> = channels
.iter()
.map(|(n, c)| (n.clone(), PyArray2::zeros_bound(py, (2, c.length), false)))
.collect();
let mut sampler = Sampler::new(results);
for (n, c) in channels {
// SAFETY: These arrays are just created.
let array = unsafe { waveforms[&n].as_array_mut() };
sampler.add_channel(
n,
pl,
c.length,
array,
Frequency::new(c.sample_rate).unwrap(),
Time::new(c.delay).unwrap(),
c.align_level,
Expand All @@ -2042,12 +2045,12 @@ fn generate_waveforms(
if let Some((crosstalk, names)) = &crosstalk {
sampler.set_crosstalk(crosstalk.as_array(), names.clone());
}
let waveforms = sampler.sample(time_tolerance);
let dict = waveforms
sampler.sample(time_tolerance);
let waveforms = waveforms
.into_iter()
.map(|(n, w)| (n, PyArray1::from_vec_bound(py, w).unbind()))
.map(|(n, w)| (n, w.unbind()))
.collect();
Ok(dict)
Ok(waveforms)
}

/// Generates microwave pulses for superconducting quantum computing
Expand Down
Loading

0 comments on commit cb3e507

Please sign in to comment.