Skip to content

Commit f49be16

Browse files
committed
allow to save plot instead of showing it
and add more tests
1 parent 42f2ed9 commit f49be16

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

PackageInfo.g

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ SetPackageInfo( rec(
1010

1111
PackageName := "MachineLearningForCAP",
1212
Subtitle := "Exploring categorical machine learning in CAP",
13-
Version := "2024.07-10",
13+
Version := "2024.07-12",
1414
Date := (function ( ) if IsBound( GAPInfo.SystemEnvironment.GAP_PKG_RELEASE_DATE ) then return GAPInfo.SystemEnvironment.GAP_PKG_RELEASE_DATE; else return Concatenation( ~.Version{[ 1 .. 4 ]}, "-", ~.Version{[ 6, 7 ]}, "-01" ); fi; end)( ),
1515
License := "GPL-2.0-or-later",
1616

examples/Expressions.g

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,37 @@ e := Sin( x1 ) / Cos( x1 ) + Sin( x2 ) ^ 2 + Cos( x2 ) ^ 2;
6060
#! Sin( x1 ) / Cos( x1 ) + Sin( x2 ) ^ 2 + Cos( x2 ) ^ 2
6161
SimplifyExpressionUsingPython( [ e ] );
6262
#! [ "Tan(x1) + 1" ]
63+
Diff( e, 1 )( dummy_input );
64+
#! Sin( x1 ) ^ 2 / Cos( x1 ) ^ 2 + 1
65+
LazyDiff( e, 1 )( dummy_input );;
66+
# Diff( [ "x1", "x2", "x3" ],
67+
# "(((Sin(x1))/(Cos(x1)))+((Sin(x2))^(2)))+((Cos(x2))^(2))", 1 )( [ x1, x2, x3 ] );
68+
JacobianMatrixUsingPython( [ x1*Cos(x2)+Exp(x3), x1*x2*x3 ], [ 1, 2, 3 ] );
69+
#! [ [ "Cos(x2)", "-x1*Sin(x2)", "Exp(x3)" ], [ "x2*x3", "x1*x3", "x1*x2" ] ]
70+
LaTeXOutputUsingPython( e );
71+
#! "\\frac{\\sin{\\left(x_{1} \\right)}}{\\cos{\\left(x_{1} \\right)}}
72+
#! + \\sin^{2}{\\left(x_{2} \\right)} + \\cos^{2}{\\left(x_{2} \\right)}"
73+
sigmoid := Expression( [ "x" ], "Exp(x)/(1+Exp(x))" );
74+
#! Exp( x ) / (1 + Exp( x ))
75+
sigmoid := AsFunction( sigmoid );
76+
#! function( vec ) ... end
77+
sigmoid( [ 0 ] );
78+
#! 0.5
79+
points := List( 0.1 * [ -20 .. 20 ], x -> [ x, sigmoid( [ x ] ) ] );
80+
#! [ [ -2., 0.119203 ], [ -1.9, 0.130108 ], [ -1.8, 0.141851 ], [ -1.7, 0.154465 ],
81+
#! [ -1.6, 0.167982 ], [ -1.5, 0.182426 ], [ -1.4, 0.197816 ], [ -1.3, 0.214165 ],
82+
#! [ -1.2, 0.231475 ], [ -1.1, 0.24974 ], [ -1., 0.268941 ], [ -0.9, 0.28905 ],
83+
#! [ -0.8, 0.310026 ], [ -0.7, 0.331812 ], [ -0.6, 0.354344 ], [ -0.5, 0.377541 ],
84+
#! [ -0.4, 0.401312 ], [ -0.3, 0.425557 ], [ -0.2, 0.450166 ], [ -0.1, 0.475021 ],
85+
#! [ 0., 0.5 ], [ 0.1, 0.524979 ], [ 0.2, 0.549834 ], [ 0.3, 0.574443 ],
86+
#! [ 0.4, 0.598688 ], [ 0.5, 0.622459 ], [ 0.6, 0.645656 ], [ 0.7, 0.668188 ],
87+
#! [ 0.8, 0.689974 ], [ 0.9, 0.71095 ], [ 1., 0.731059 ], [ 1.1, 0.75026 ],
88+
#! [ 1.2, 0.768525 ], [ 1.3, 0.785835 ], [ 1.4, 0.802184 ], [ 1.5, 0.817574 ],
89+
#! [ 1.6, 0.832018 ], [ 1.7, 0.845535 ], [ 1.8, 0.858149 ], [ 1.9, 0.869892 ],
90+
#! [ 2., 0.880797 ] ]
91+
labels := List( points, point -> SelectBasedOnCondition( point[2] < 0.5, 0, 1 ) );
92+
#! [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
93+
#! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ]
94+
ScatterPlotUsingPython( points, labels : size := "100", action := "save" );;
95+
# e.g, dir("/tmp/gaptempdirX7Qsal/")
6396
#! @EndExample

gap/Tools.gi

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ InstallMethod( ScatterPlotUsingPython,
707707
[ IsDenseList, IsDenseList ],
708708

709709
function ( points, labels )
710-
local dir, path, file, size, stream, err, p;
710+
local dir, path, file, size, action, stream, err, p;
711711

712712
dir := DirectoryTemporary( );
713713

@@ -719,6 +719,8 @@ InstallMethod( ScatterPlotUsingPython,
719719

720720
size := CAP_INTERNAL_RETURN_OPTION_OR_DEFAULT( "size", "20" );
721721

722+
action := CAP_INTERNAL_RETURN_OPTION_OR_DEFAULT( "action", "show" );
723+
722724
IO_Write( file,
723725
Concatenation(
724726
"import matplotlib.pyplot as plt\n",
@@ -779,7 +781,10 @@ InstallMethod( ScatterPlotUsingPython,
779781
"plt.ylabel('Y-axis')\n",
780782
"plt.title('Scatter Plot using Matplotlib')\n",
781783
"plt.legend()\n",
782-
"plt.show()\n" ) );
784+
SelectBasedOnCondition(
785+
action = "save",
786+
Concatenation( "plt.savefig('", Filename( dir, "plot.png" ), "', dpi=400)\n" ),
787+
"plt.show()\n" ) ) );
783788

784789
IO_Close( file );
785790

@@ -801,6 +806,6 @@ InstallMethod( ScatterPlotUsingPython,
801806

802807
fi;
803808

804-
return true;
809+
return dir;
805810

806811
end );

0 commit comments

Comments
 (0)