Skip to content

Commit

Permalink
add dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz committed Dec 7, 2024
1 parent 7f3fcdf commit e7fe64a
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 0 deletions.
184 changes: 184 additions & 0 deletions common/dataset/dataset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// Copyright 2024 gorse Project Authors

Check failure on line 1 in common/dataset/dataset.go

View workflow job for this annotation

GitHub Actions / lint

: # github.com/zhenghaoz/gorse/common/dataset [github.com/zhenghaoz/gorse/common/dataset.test]
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataset

import (
"archive/zip"
"encoding/csv"
"fmt"
"github.com/zhenghaoz/gorse/base/log"
"go.uber.org/zap"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"strconv"
"strings"
)

var (
tempDir string
datasetDir string
)

func init() {
usr, err := user.Current()
if err != nil {
log.Logger().Fatal("failed to get user directory", zap.Error(err))
}
datasetDir = filepath.Join(usr.HomeDir, ".gorse", "dataset")
tempDir = filepath.Join(usr.HomeDir, ".gorse", "temp")
}

func LoadIris() ([][]float32, []int, error) {
// Download dataset
path, err := downloadAndUnzip("iris")
if err != nil {
return nil, nil, err
}
dataFile := filepath.Join(path, "iris.data")
// Load data
f, err := os.Open(dataFile)
if err != nil {
return nil, nil, err
}
reader := csv.NewReader(f)
rows, err := reader.ReadAll()
if err != nil {
return nil, nil, err
}
// Parse data
data := make([][]float32, len(rows))
target := make([]int, len(rows))
types := make(map[string]int)
for i, row := range rows {
data[i] = make([]float32, 4)
for j, cell := range row[:4] {
data[i][j], err = strconv.ParseFloat(cell, 64)

Check failure on line 70 in common/dataset/dataset.go

View workflow job for this annotation

GitHub Actions / unit tests (Windows)

cannot use strconv.ParseFloat(cell, 64) (value of type float64) as float32 value in assignment

Check failure on line 70 in common/dataset/dataset.go

View workflow job for this annotation

GitHub Actions / unit tests

cannot use strconv.ParseFloat(cell, 64) (value of type float64) as float32 value in assignment

Check failure on line 70 in common/dataset/dataset.go

View workflow job for this annotation

GitHub Actions / unit tests

cannot use strconv.ParseFloat(cell, 64) (value of type float64) as float32 value in assignment

Check failure on line 70 in common/dataset/dataset.go

View workflow job for this annotation

GitHub Actions / unit tests (macOS)

cannot use strconv.ParseFloat(cell, 64) (value of type float64) as float32 value in assignment
if err != nil {
return nil, nil, err
}
}
if _, exist := types[row[4]]; !exist {
types[row[4]] = len(types)
}
target[i] = types[row[4]]
}
return data, target, nil
}

func downloadAndUnzip(name string) (string, error) {
url := fmt.Sprintf("https://pub-64226d9f34c64d6f829f5b63a5540d27.r2.dev/datasets/%s.zip", name)
path := filepath.Join(datasetDir, name)
if _, err := os.Stat(path); os.IsNotExist(err) {
zipFileName, _ := downloadFromUrl(url, tempDir)
if _, err := unzip(zipFileName, path); err != nil {
return "", err
}
}
return path, nil
}

// downloadFromUrl downloads file from URL.
func downloadFromUrl(src, dst string) (string, error) {
log.Logger().Info("Download dataset", zap.String("source", src), zap.String("destination", dst))
// Extract file name
tokens := strings.Split(src, "/")
fileName := filepath.Join(dst, tokens[len(tokens)-1])
// Create file
if err := os.MkdirAll(filepath.Dir(fileName), os.ModePerm); err != nil {
return fileName, err
}
output, err := os.Create(fileName)
if err != nil {
log.Logger().Error("failed to create file", zap.Error(err), zap.String("filename", fileName))
return fileName, err
}
defer output.Close()
// Download file
response, err := http.Get(src)
if err != nil {
log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src))
return fileName, err
}
defer response.Body.Close()
// Save file
_, err = io.Copy(output, response.Body)
if err != nil {
log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src))
return fileName, err
}
return fileName, nil
}

// unzip zip file.
func unzip(src, dst string) ([]string, error) {
var fileNames []string
// Open zip file
r, err := zip.OpenReader(src)
if err != nil {
return fileNames, err
}
defer r.Close()
// Extract files
for _, f := range r.File {
// Open file
rc, err := f.Open()
if err != nil {
return fileNames, err
}
// Store filename/path for returning and using later on
filePath := filepath.Join(dst, f.Name)
// Check for ZipSlip. More Info: http://bit.ly/2MsjAWE
if !strings.HasPrefix(filePath, filepath.Clean(dst)+string(os.PathSeparator)) {
return fileNames, fmt.Errorf("%s: illegal file path", filePath)
}
// Add filename
fileNames = append(fileNames, filePath)
if f.FileInfo().IsDir() {
// Create folder
if err = os.MkdirAll(filePath, os.ModePerm); err != nil {
return fileNames, err
}
} else {
// Create all folders
if err = os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
return fileNames, err
}
// Create file
outFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
if err != nil {
return fileNames, err
}
// Save file
_, err = io.Copy(outFile, rc)
if err != nil {
return nil, err
}
// Close the file without defer to close before next iteration of loop
err = outFile.Close()
if err != nil {
return nil, err
}
}
// Close file
err = rc.Close()
if err != nil {
return nil, err
}
}
return fileNames, nil
}
26 changes: 26 additions & 0 deletions common/dataset/dataset_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package dataset

import (
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/common/nn"
"testing"
)

func TestIris(t *testing.T) {
data, target, err := LoadIris()
assert.NoError(t, err)
_ = data
_ = target

x := nn.NewTensor(lo.Flatten(data), len(data), 4)

Check failure on line 16 in common/dataset/dataset_test.go

View workflow job for this annotation

GitHub Actions / unit tests (Windows)

declared and not used: x

Check failure on line 16 in common/dataset/dataset_test.go

View workflow job for this annotation

GitHub Actions / unit tests

declared and not used: x

Check failure on line 16 in common/dataset/dataset_test.go

View workflow job for this annotation

GitHub Actions / unit tests (macOS)

declared and not used: x

model := nn.NewSequential(
nn.NewLinear(4, 100),
nn.NewReLU(),
nn.NewLinear(100, 100),
nn.NewLinear(100, 3),
nn.NewFlatten(),
)
_ = model
}
14 changes: 14 additions & 0 deletions common/nn/layers.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ func (e *embeddingLayer) Forward(x *Tensor) *Tensor {
return Embedding(e.w, x)
}

type reluLayer struct{}

func NewReLU() Layer {
return &reluLayer{}
}

func (r *reluLayer) Parameters() []*Tensor {
return nil
}

func (r *reluLayer) Forward(x *Tensor) *Tensor {
return ReLu(x)
}

type Sequential struct {
layers []Layer
}
Expand Down

0 comments on commit e7fe64a

Please sign in to comment.