-
Notifications
You must be signed in to change notification settings - Fork 2
/
model_arch_2.dot
85 lines (67 loc) · 2.41 KB
/
model_arch_2.dot
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
digraph DiffusionModel {
    // Architecture diagram: cluster_0 is the top-level DDIM diffusion flow;
    // clusters 1-3 expand components that appear as single nodes elsewhere.
    // Red dashed edges link a summary node to the cluster that expands it and
    // back again. The "back" edges use constraint=false so they do not form
    // rank cycles that distort the top-to-bottom layout.
    rankdir=TB;
    node [shape=box, style=filled, color=lightblue];

    // Top-level model: noise + q_sample feed the transformer, whose
    // prediction drives either a p_sample step or the loss.
    subgraph cluster_0 {
        label = "DDIMDiffusionModel";
        style=filled;
        color=lightgrey;
        input [label="Input (x_start, x_cond)"];
        noise [label="Noise Generation"];
        q_sample [label="q_sample"];
        CustomTransformer [label="CustomTransformer"];
        p_sample [label="p_sample"];
        loss [label="Loss Calculation"];
        output [label="Output (x_t_prev or loss)"];
        input -> noise;
        input -> q_sample;
        noise -> q_sample;
        q_sample -> CustomTransformer;
        CustomTransformer -> p_sample;
        CustomTransformer -> loss;
        p_sample -> output;
        loss -> output;
    }

    // Expansion of the CustomTransformer node from cluster_0.
    subgraph cluster_1 {
        label = "CustomTransformer";
        style=filled;
        color=lightgrey;
        transformer_input [label="Input (x_t, x_cond, t)"];
        input_proj [label="Input Projection"];
        cond_proj [label="Conditional Projection"];
        time_emb [label="Time Embedding"];
        rope [label="RoPE"];
        transformer_layers [label="Transformer Layers"];
        output_proj [label="Output Projection"];
        transformer_output [label="Output (eps_pred)"];
        transformer_input -> input_proj;
        transformer_input -> cond_proj;
        transformer_input -> time_emb;
        input_proj -> rope;
        cond_proj -> rope;
        rope -> transformer_layers;
        time_emb -> transformer_layers;
        transformer_layers -> output_proj;
        output_proj -> transformer_output;
    }

    // A single transformer layer; presumably what the "Transformer Layers"
    // node in cluster_1 stacks (matched by name).
    subgraph cluster_2 {
        label = "CustomTransformerLayer";
        style=filled;
        color=lightgrey;
        attention [label="MultiheadAttention"];
        norm1 [label="LayerNorm 1"];
        ff [label="FeedForward"];
        norm2 [label="LayerNorm 2"];
        attention -> norm1 -> ff -> norm2;
    }

    // Expansion of the "Time Embedding" node in cluster_1 (matched by name).
    subgraph cluster_3 {
        label = "TimeEmbedding";
        style=filled;
        color=lightgrey;
        linear1 [label="Linear 1"];
        gelu [label="GELU"];
        linear2 [label="Linear 2"];
        linear1 -> gelu -> linear2;
    }

    // Summary node -> expansion cluster (forward) and back (constraint=false,
    // so the return edge does not take part in ranking).
    CustomTransformer -> transformer_input [style=dashed, color=red];
    transformer_output -> CustomTransformer [style=dashed, color=red, constraint=false];
    transformer_layers -> attention [style=dashed, color=red];
    norm2 -> transformer_layers [style=dashed, color=red, constraint=false];
    time_emb -> linear1 [style=dashed, color=red];
    linear2 -> time_emb [style=dashed, color=red, constraint=false];
}