Skip to content
This repository has been archived by the owner on May 30, 2019. It is now read-only.

Commit

Permalink
Save datasets to cache.
Browse files Browse the repository at this point in the history
And ensure datasets reference propelml.org URLs, not filesystem paths.

Fixes #375
  • Loading branch information
ry committed Mar 24, 2018
1 parent cc1c73a commit 52d25bd
Show file tree
Hide file tree
Showing 9 changed files with 255 additions and 32 deletions.
113 changes: 113 additions & 0 deletions src/cache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*!
Copyright 2018 Propel http://propel.site/. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// When people load datasets in Propel, the data has to be downloaded over
// HTTP. To avoid making people download the same dataset every time they
// start a training program, we provide a local cache of these datasets.
// The $HOME/.propel/cache directory is where these files will be stored.

import * as path from "path";
import * as rimraf from "rimraf";
import { assert, fetchArrayBuffer, IS_WEB, nodeRequire, URL } from "./util";
import { mkdirp, propelDir } from "./util_node";

/**
 * Storage backend for downloaded datasets, keyed by the URL the data was
 * fetched from.
 */
export interface Cache {
  /** Removes every entry held by the cache. */
  clearAll(): Promise<void>;

  /** Resolves to the bytes stored for `url`, or null if nothing is stored. */
  get(url: string): Promise<ArrayBuffer | null>;

  /** Stores `ab` as the cached content for `url`. */
  set(url: string, ab: ArrayBuffer): Promise<void>;
}

// The active cache backend. It is assigned once, when this module loads.
let cacheImpl: Cache;

// TODO move this function to src/fetch.ts
/**
 * Returns the contents of `url`, preferring the cache over the network.
 * On a cache miss the data is downloaded with fetchArrayBuffer() and stored
 * in the cache before it is returned.
 */
export async function fetchWithCache(url: string): Promise<ArrayBuffer> {
  const cached = await cacheImpl.get(url);
  if (cached != null) {
    return cached;
  }
  const ab = await fetchArrayBuffer(url);
  // Wait for the write so the entry exists by the time we return, and so a
  // failed write cannot surface as an unhandled promise rejection. The
  // download itself succeeded, so a cache failure is reported, not thrown.
  try {
    await cacheImpl.set(url, ab);
  } catch (e) {
    console.warn(`Failed to cache ${url}:`, e);
  }
  return ab;
}

/** Empties the active cache backend. Resolves once it has finished. */
export function clearAll(): Promise<void> {
  const done = cacheImpl.clearAll();
  return done;
}

/** Absolute path of the "cache" directory inside the Propel directory. */
function cacheBase(): string {
  const root = propelDir();
  return path.resolve(root, "cache");
}

// Maps a URL to a cache filename. Example:
// "http://propelml.org/data/mnist/train-images-idx3-ubyte.bin"
// "$HOME/.propel/cache/propelml.org/data/mnist/train-images-idx3-ubyte.bin"
// Only the hostname and pathname of the URL are used to build the filename.
// Throws if the protocol is not http(s) or if the pathname contains "..".
export function url2Filename(url: string): string {
  // Throw on browser. We expose this method for testing, but only run it on
  // Node.
  assert(!IS_WEB, "url2Filename is unsupported in the browser");
  const u = new URL(url);
  if (!(u.protocol === "http:" || u.protocol === "https:")) {
    throw Error(`Unsupported protocol '${u.protocol}'`);
  }
  if (u.pathname.indexOf("..") >= 0) {
    throw Error("Cache name cannot include '..'");
  }
  // Note we purposely leave the port out of the cache path because
  // Windows doesn't allow colons in filenames. This is probably fine
  // in 99% of cases and is the simplest solution.
  const base = cacheBase();
  const cacheFn = path.resolve(path.join(base, u.hostname, u.pathname));
  // Compare against the base plus a separator: a bare prefix test would also
  // accept sibling directories such as "<base>-other".
  assert(cacheFn.startsWith(base + path.sep),
         "Cache filename escapes the cache directory");
  return cacheFn;
}

// Pick the cache backend once, at module load time.
if (IS_WEB) {
  // On web do nothing. No caching.
  // Maybe use local storage?
  cacheImpl = {
    async clearAll(): Promise<void> { },
    async get(url: string): Promise<null | ArrayBuffer> {
      return null;
    },
    async set(url: string, ab: ArrayBuffer): Promise<void> { },
  };
} else {
  // Node caching uses the disk.
  const fs = nodeRequire("fs");

  cacheImpl = {
    // Removes the whole cache directory.
    async clearAll(): Promise<void> {
      rimraf.sync(cacheBase());
      console.log("Delete cache dir", cacheBase());
    },

    // Returns the cached bytes for `url`, or null if there is no cache file.
    async get(url: string): Promise<null | ArrayBuffer> {
      const cacheFn = url2Filename(url);
      if (fs.existsSync(cacheFn)) {
        const b = fs.readFileSync(cacheFn, null);
        // Copy out exactly the bytes of this Buffer; `b.buffer` may be a
        // larger ArrayBuffer that `b` is only a view into.
        return b.buffer.slice(b.byteOffset,
                              b.byteOffset + b.byteLength) as ArrayBuffer;
      } else {
        return null;
      }
    },

    // Stores `ab` in the cache file for `url`, creating directories as
    // needed.
    async set(url: string, ab: ArrayBuffer): Promise<void> {
      const cacheFn = url2Filename(url);
      mkdirp(path.dirname(cacheFn));
      // Write to a temporary sibling and rename it into place. Writing the
      // final path directly would leave a truncated file behind if the
      // process were interrupted, and get() would later serve it as valid.
      const tmpFn = `${cacheFn}.${process.pid}.tmp`;
      try {
        fs.writeFileSync(tmpFn, Buffer.from(ab));
        fs.renameSync(tmpFn, cacheFn);
      } catch (e) {
        try {
          fs.unlinkSync(tmpFn);
        } catch (unlinkErr) {
          // Nothing to clean up, or cleanup failed: report the first error.
        }
        throw e;
      }
    },
  };
}
70 changes: 70 additions & 0 deletions src/cache_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*!
Copyright 2018 Propel http://propel.site/. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// The url2Filename test is Node-only. In the browser the cache is a no-op,
// since browsers have caching built-in.

import * as fs from "fs";
import { test } from "../tools/tester";
import * as cache from "./cache";
import { assert, IS_NODE, nodeRequire } from "./util";
import { isDir } from "./util_node";

// Helper function to start a local web server.
// Calls `cb` with the base URL of the server (it ends with "/") and, on
// Node, shuts the server down once `cb` has settled.
// TODO should be moved to tools/tester eventually.
async function localServer(cb: (url: string) => Promise<void>): Promise<void> {
  if (!IS_NODE) {
    // We don't need a local server, since we're being hosted from one already.
    await cb(`http://${document.location.host}/`);
    return;
  }

  const root = __dirname + "/../build/dev_website";
  assert(isDir(root), root +
    " does not exist. Run ./tools/dev_website before running this test.");
  const httpServer = nodeRequire("http-server");
  const server = httpServer.createServer({ cors: true, root });
  // No port is given to listen(); read back whichever one was assigned.
  server.listen();
  const { port } = server.server.address();
  try {
    await cb(`http://127.0.0.1:${port}/`);
  } finally {
    server.close();
  }
}

if (IS_NODE) {
  // Checks that a URL maps to the matching path below ".propel/cache".
  test(async function cache_url2Filename() {
    const path = nodeRequire("path");
    const filename = cache.url2Filename(
      "http://propelml.org/data/mnist/train-images-idx3-ubyte.bin");
    // Split and join done for windows compat.
    const segments =
      ".propel/cache/propelml.org/data/mnist/train-images-idx3-ubyte.bin"
        .split("/");
    const suffix = path.join(...segments);
    assert(filename.endsWith(suffix));
  });
}

// Downloads a dataset file through the cache from a local server and checks
// its size. On Node it also checks that the file landed in the disk cache.
test(async function cache_fetchWithCache() {
  // Wait until the cache is really empty before fetching.
  await cache.clearAll();
  await localServer(async function(url: string) {
    // `url` already ends with "/", so do not add another one.
    url += "data/mnist/train-images-idx3-ubyte.bin";
    const ab = await cache.fetchWithCache(url);
    assert(ab.byteLength === 47040016);
    if (IS_NODE) {
      assert(fs.existsSync(cache.url2Filename(url)));
    }
  });
});
15 changes: 10 additions & 5 deletions src/dataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

// This module is inspired by TensorFlow's tf.data.Dataset.
// https://www.tensorflow.org/api_docs/python/tf/data/Dataset
import { TextDecoder } from "text-encoding";
import { isUndefined } from "util";
import { stack, tensor, Tensor } from "./api";
import * as cache from "./cache";
import * as mnist from "./mnist";
import { NamedTensors } from "./tensor";
import { assert, delay, fetchStr } from "./util";
import { assert, delay } from "./util";

export function datasetFromSlices(tensors: NamedTensors): Dataset {
return new SliceDataset(tensors);
Expand Down Expand Up @@ -271,7 +273,10 @@ class ShuffleDataset extends Dataset {

async function loadData(fn: string):
Promise<{ features: Tensor, labels: Tensor }> {
const csv = await fetchStr(fn);
const ab = await cache.fetchWithCache(fn);
const dec = new TextDecoder("ascii");
const csv = dec.decode(new Uint8Array(ab));

const lines = csv.trim().split("\n").map(line => line.split(","));
const header = lines.shift();
const nSamples = Number(header.shift());
Expand Down Expand Up @@ -302,17 +307,17 @@ async function loadData(fn: string):
*/
export async function loadIris():
    Promise<{ features: Tensor, labels: Tensor }> {
  // Referenced by URL rather than by filesystem path; the stale
  // "deps/data/..." return made this line unreachable.
  return loadData("http://propelml.org/data/iris.csv");
}

/**
 * Loads the breast cancer dataset from propelml.org and resolves to its
 * features and labels tensors.
 */
export async function loadBreastCancer():
    Promise<{ features: Tensor, labels: Tensor }> {
  // Referenced by URL rather than by filesystem path; the stale
  // "deps/data/..." return made this line unreachable.
  return loadData("http://propelml.org/data/breast_cancer.csv");
}

/**
 * Loads the wine dataset from propelml.org and resolves to its features and
 * labels tensors.
 */
export async function loadWine():
    Promise<{ features: Tensor, labels: Tensor }> {
  // Referenced by URL rather than by filesystem path; the stale
  // "deps/data/..." return made this line unreachable.
  return loadData("http://propelml.org/data/wine_data.csv");
}

// TODO
Expand Down
21 changes: 1 addition & 20 deletions src/disk_experiment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,7 @@ import { Experiment, ExperimentOpts, print } from "./experiment";
import * as npy from "./npy";
import { Params, params as createParams } from "./params";
import { assert } from "./util";

/**
 * Returns the PROPEL_DIR env var if it is set, otherwise ".propel/" inside
 * the user's home directory (USERPROFILE on Windows, HOME elsewhere).
 * Throws if neither PROPEL_DIR nor the home directory variable is set.
 */
function propelDir(): string {
  if (process.env.PROPEL_DIR) {
    return process.env.PROPEL_DIR;
  }
  const homeVar = process.platform === "win32" ? "USERPROFILE" : "HOME";
  const homeDir = process.env[homeVar];
  if (!homeDir) {
    // path.join() would otherwise fail with an opaque TypeError.
    throw Error("Cannot locate the Propel directory: " +
                `set PROPEL_DIR or ${homeVar}.`);
  }
  return path.join(homeDir, ".propel/");
}
import { isDir, propelDir } from "./util_node";

export class DiskExperiment extends Experiment {
constructor(readonly name: string, opts?: ExperimentOpts) {
Expand Down Expand Up @@ -126,15 +116,6 @@ export class DiskExperiment extends Experiment {
}
}

/**
 * Returns true if `p` exists and is a directory. Returns false if `p` is some
 * other kind of file or does not exist. Any other stat error is rethrown.
 */
export function isDir(p: string): boolean {
  try {
    return fs.statSync(p).isDirectory();
  } catch (e) {
    // ENOENT: nothing exists at `p`. ENOTDIR: a parent component of `p` is
    // not a directory. Either way `p` is not a directory.
    const code = (e as { code?: string }).code;
    if (code === "ENOENT" || code === "ENOTDIR") return false;
    throw e;
  }
}

function filePatternSearch(p: string, pattern: RegExp): string[] {
if (isDir(p)) {
let results = [];
Expand Down
3 changes: 2 additions & 1 deletion src/disk_experiment_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ import * as os from "os";
import * as path from "path";
import * as rimraf from "rimraf";
import { test } from "../tools/tester";
import { DiskExperiment, isDir } from "./disk_experiment";
import { DiskExperiment } from "./disk_experiment";
import { assert, assertAllEqual } from "./tensor_util";
import { isDir } from "./util_node";

function setup() {
process.env.PROPEL_DIR = path.join(os.tmpdir(), "propel_test");
Expand Down
3 changes: 3 additions & 0 deletions src/experiment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ const defaultOpts: ExperimentOpts = {
export async function experiment(name: string,
opts?: ExperimentOpts): Promise<Experiment> {
let exp: Experiment;
if (name === "cache") {
throw Error("Invalid experiment name.");
}
if (IS_NODE) {
const { DiskExperiment } = require("./disk_experiment");
exp = new DiskExperiment(name, opts);
Expand Down
13 changes: 7 additions & 6 deletions src/mnist.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@
limitations under the License.
*/
import { tensor, Tensor } from "./api";
import { assert, fetchArrayBuffer } from "./util";
import * as cache from "./cache";
import { assert } from "./util";

export function filenames(split: string): [string, string] {
if (split === "train") {
return [
"deps/data/mnist/train-labels-idx1-ubyte.bin",
"deps/data/mnist/train-images-idx3-ubyte.bin",
"http://propelml.org/data/mnist/train-labels-idx1-ubyte.bin",
"http://propelml.org/data/mnist/train-images-idx3-ubyte.bin",
];
} else if (split === "test") {
return [
"deps/data/mnist/t10k-labels-idx1-ubyte.bin",
"deps/data/mnist/t10k-images-idx3-ubyte.bin",
"http://propelml.org/data/mnist/t10k-labels-idx1-ubyte.bin",
"http://propelml.org/data/mnist/t10k-images-idx3-ubyte.bin",
];
} else {
throw new Error(`Bad split: ${split}`);
Expand All @@ -48,7 +49,7 @@ export async function loadSplit(split: string):
}

async function loadFile2(href: string) {
const ab = await fetchArrayBuffer(href);
const ab = await cache.fetchWithCache(href);
const i32 = new Int32Array(ab);
const ui8 = new Uint8Array(ab);

Expand Down
48 changes: 48 additions & 0 deletions src/util_node.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*!
Copyright 2018 Propel http://propel.site/. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

import * as fs from "fs";
import * as path from "path";
import { assert } from "./util";

/**
 * Returns true if `p` exists and is a directory. Returns false if `p` is some
 * other kind of file or does not exist. Any other stat error is rethrown.
 */
export function isDir(p: string): boolean {
  try {
    return fs.statSync(p).isDirectory();
  } catch (e) {
    // ENOENT: nothing exists at `p`. ENOTDIR: a parent component of `p` is
    // not a directory. Either way `p` is not a directory.
    const code = (e as { code?: string }).code;
    if (code === "ENOENT" || code === "ENOTDIR") return false;
    throw e;
  }
}

/**
 * Returns the PROPEL_DIR env var if it is set, otherwise ".propel/" inside
 * the user's home directory (USERPROFILE on Windows, HOME elsewhere).
 * Throws if neither PROPEL_DIR nor the home directory variable is set.
 */
export function propelDir(): string {
  if (process.env.PROPEL_DIR) {
    return process.env.PROPEL_DIR;
  }
  const homeVar = process.platform === "win32" ? "USERPROFILE" : "HOME";
  const homeDir = process.env[homeVar];
  if (!homeDir) {
    // path.join() would otherwise fail with an opaque TypeError.
    throw Error("Cannot locate the Propel directory: " +
                `set PROPEL_DIR or ${homeVar}.`);
  }
  return path.join(homeDir, ".propel/");
}

/**
 * Recursive mkdir: creates `dirname` and any missing parent directories.
 * Does nothing if `dirname` is already a directory.
 */
export function mkdirp(dirname: string): void {
  if (isDir(dirname)) {
    return;
  }
  const parent = path.dirname(dirname);
  // Refuse to continue when there is no distinct parent left, or when the
  // parent path is a single character long.
  assert(parent !== dirname && parent.length > 1);
  // Parents first, then the directory itself.
  mkdirp(parent);
  fs.mkdirSync(dirname);
}
1 change: 1 addition & 0 deletions tools/test_isomorphic.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import "../src/api_test";
import "../src/backend_test";
import "../src/cache_test";
import "../src/conv_test";
import "../src/dataset_test";
import "../src/example_test";
Expand Down

0 comments on commit 52d25bd

Please sign in to comment.