From 06e20dd8bd88665f4260da096e19ee03f4a2aeb1 Mon Sep 17 00:00:00 2001 From: Ryan Dahl Date: Thu, 22 Mar 2018 14:26:09 -0400 Subject: [PATCH] Save datasets to cache. And ensure datasets reference propelml.org URLs, not filesystem paths. Fixes #375 --- src/cache.ts | 108 ++++++++++++++++++++++++++++++++++++ src/cache_test.ts | 63 +++++++++++++++++++++ src/dataset.ts | 15 +++-- src/disk_experiment.ts | 21 +------ src/disk_experiment_test.ts | 3 +- src/experiment.ts | 3 + src/mnist.ts | 13 +++-- src/util_node.ts | 48 ++++++++++++++++ tools/test_isomorphic.ts | 1 + 9 files changed, 243 insertions(+), 32 deletions(-) create mode 100644 src/cache.ts create mode 100644 src/cache_test.ts create mode 100644 src/util_node.ts diff --git a/src/cache.ts b/src/cache.ts new file mode 100644 index 00000000..11f6ab75 --- /dev/null +++ b/src/cache.ts @@ -0,0 +1,108 @@ +/*! + Copyright 2018 Propel http://propel.site/. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +// When people load datasets in Propel, they need to be downloaded from HTTP. +// To avoid making people download the same dataset every time they start +// a training program, we provide a local cache of these datasets. +// The $HOME/.propel/cache directory is where these files will be stored. + +import * as path from "path"; +import * as rimraf from "rimraf"; +import { assert, fetchArrayBuffer, IS_WEB, nodeRequire, URL } from "./util"; +import { mkdirp, propelDir } from "./util_node"; + +export interface Cache { + clearAll(): Promise; + get(url: string): Promise; + set(url: string, ab: ArrayBuffer): Promise; +} + +let cacheImpl: Cache; + +// TODO move this function to src/fetch.ts +export async function fetchWithCache(url: string): Promise { + let ab = await cacheImpl.get(url); + if (ab != null) { + return ab; + } + ab = await fetchArrayBuffer(url); + cacheImpl.set(url, ab); + return ab; +} + +export function clearAll(): Promise { + return cacheImpl.clearAll(); +} + +function cacheBase(): string { + return path.resolve(propelDir(), "cache"); +} + +// Maps a URL to a cache filename. Example: +// "http://propelml.org/data/mnist/train-images-idx3-ubyte.bin" +// "$HOME/.propel/cache/propelml.org/data/mnist/train-images-idx3-ubyte.bin" +export function url2Filename(url: string): string { + if (IS_WEB) return ""; + const u = new URL(url); + if (!(u.protocol === "http:" || u.protocol === "https:")) { + throw Error(`Unsupported protocol '${u.protocol}'`); + } + if (u.pathname.indexOf("..") >= 0) { + throw Error("Cache name cannot include '..'"); + } + const cacheFn = path.resolve(path.join(cacheBase(), u.host, u.pathname)); + assert(cacheFn.startsWith(cacheBase())); + return cacheFn; +} + +if (IS_WEB) { + // On web do nothing. No caching. + // Maybe use local storage? + cacheImpl = { + async clearAll(): Promise { }, + async get(url: string): Promise { + return null; + }, + async set(url: string, ab: ArrayBuffer): Promise { }, + }; +} else { + // Node caching uses the disk. + const fs = nodeRequire("fs"); + + cacheImpl = { + async clearAll(): Promise { + rimraf.sync(cacheBase()); + console.log("Delete cache dir", cacheBase()); + }, + + async get(url: string): Promise { + const cacheFn = url2Filename(url); + if (fs.existsSync(cacheFn)) { + const b = fs.readFileSync(cacheFn, null); + return b.buffer.slice(b.byteOffset, + b.byteOffset + b.byteLength) as ArrayBuffer; + } else { + return null; + } + }, + + async set(url: string, ab: ArrayBuffer): Promise { + const cacheFn = url2Filename(url); + const cacheDir = path.dirname(cacheFn); + mkdirp(cacheDir); + fs.writeFileSync(cacheFn, Buffer.from(ab)); + }, + }; +} diff --git a/src/cache_test.ts b/src/cache_test.ts new file mode 100644 index 00000000..61d39cce --- /dev/null +++ b/src/cache_test.ts @@ -0,0 +1,63 @@ +/*! + Copyright 2018 Propel http://propel.site/. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +// Node-only test. Browsers have caching built-in. + +import * as fs from "fs"; +import { test } from "../tools/tester"; +import * as cache from "./cache"; +import { assert, IS_NODE, nodeRequire } from "./util"; + +// Helper function to start a local web server. +// TODO should be moved to tools/tester eventually. +async function localServer(cb: (url: string) => Promise): Promise { + if (!IS_NODE) { + // We don't need a local server, since we're being hosted from one already. + await cb(`http://${document.location.host}/`); + } else { + const { createServer } = nodeRequire("http-server"); + const server = createServer({ cors: true, root: "./build/dev_website" }); + server.listen(); + const port = server.server.address().port; + const url = `http://127.0.0.1:${port}/`; + try { + await cb(url); + } finally { + server.close(); + } + } +} + +if (IS_NODE) { + test(async function cache_url2Filename() { + const home = process.env.HOME; + const actual = cache.url2Filename( + "http://propelml.org/data/mnist/train-images-idx3-ubyte.bin"); + assert(actual === home + + "/.propel/cache/propelml.org/data/mnist/train-images-idx3-ubyte.bin"); + }); +} + +test(async function cache_fetchWithCache() { + cache.clearAll(); + await localServer(async function(url: string) { + const ab = await cache.fetchWithCache( + url + "/data/mnist/train-images-idx3-ubyte.bin"); + assert(ab.byteLength === 47040016); + if (IS_NODE) { + assert(fs.existsSync(cache.url2Filename(url))); + } + }); +}); diff --git a/src/dataset.ts b/src/dataset.ts index 05969279..353d0e3b 100644 --- a/src/dataset.ts +++ b/src/dataset.ts @@ -15,11 +15,13 @@ // This module is inspired by TensorFlow's tf.data.Dataset. // https://www.tensorflow.org/api_docs/python/tf/data/Dataset +import { TextDecoder } from "text-encoding"; import { isUndefined } from "util"; import { stack, tensor, Tensor } from "./api"; +import * as cache from "./cache"; import * as mnist from "./mnist"; import { NamedTensors } from "./tensor"; -import { assert, delay, fetchStr } from "./util"; +import { assert, delay } from "./util"; export function datasetFromSlices(tensors: NamedTensors): Dataset { return new SliceDataset(tensors); @@ -271,7 +273,10 @@ class ShuffleDataset extends Dataset { async function loadData(fn: string): Promise<{ features: Tensor, labels: Tensor }> { - const csv = await fetchStr(fn); + const ab = await cache.fetchWithCache(fn); + const dec = new TextDecoder("ascii"); + const csv = dec.decode(new Uint8Array(ab)); + const lines = csv.trim().split("\n").map(line => line.split(",")); const header = lines.shift(); const nSamples = Number(header.shift()); @@ -302,17 +307,17 @@ async function loadData(fn: string): */ export async function loadIris(): Promise<{ features: Tensor, labels: Tensor }> { - return loadData("deps/data/iris.csv"); + return loadData("http://propelml.org/data/iris.csv"); } export async function loadBreastCancer(): Promise<{ features: Tensor, labels: Tensor }> { - return loadData("deps/data/breast_cancer.csv"); + return loadData("http://propelml.org/data/breast_cancer.csv"); } export async function loadWine(): Promise<{ features: Tensor, labels: Tensor }> { - return loadData("deps/data/wine_data.csv"); + return loadData("http://propelml.org/data/wine_data.csv"); } // TODO diff --git a/src/disk_experiment.ts b/src/disk_experiment.ts index 5695b21a..86367e0e 100644 --- a/src/disk_experiment.ts +++ b/src/disk_experiment.ts @@ -24,17 +24,7 @@ import { Experiment, ExperimentOpts, print } from "./experiment"; import * as npy from "./npy"; import { Params, params as createParams } from "./params"; import { assert } from "./util"; - -/** Returns "$HOME/.propel/" or PROPEL_DIR env var. */ -function propelDir(): string { - if (process.env.PROPEL_DIR) { - return process.env.PROPEL_DIR; - } else { - const homeDir = process.platform === "win32" ? process.env.USERPROFILE - : process.env.HOME; - return path.join(homeDir, ".propel/"); - } -} +import { isDir, propelDir } from "./util_node"; export class DiskExperiment extends Experiment { constructor(readonly name: string, opts?: ExperimentOpts) { @@ -126,15 +116,6 @@ export class DiskExperiment extends Experiment { } } -export function isDir(p: string): boolean { - try { - return fs.statSync(p).isDirectory(); - } catch (e) { - if (e.code === "ENOENT") return false; - throw e; - } -} - function filePatternSearch(p: string, pattern: RegExp): string[] { if (isDir(p)) { let results = []; diff --git a/src/disk_experiment_test.ts b/src/disk_experiment_test.ts index 59bab130..a3d8ddcf 100644 --- a/src/disk_experiment_test.ts +++ b/src/disk_experiment_test.ts @@ -17,8 +17,9 @@ import * as os from "os"; import * as path from "path"; import * as rimraf from "rimraf"; import { test } from "../tools/tester"; -import { DiskExperiment, isDir } from "./disk_experiment"; +import { DiskExperiment } from "./disk_experiment"; import { assert, assertAllEqual } from "./tensor_util"; +import { isDir } from "./util_node"; function setup() { process.env.PROPEL_DIR = path.join(os.tmpdir(), "propel_test"); diff --git a/src/experiment.ts b/src/experiment.ts index 4862a440..3c82bb9f 100644 --- a/src/experiment.ts +++ b/src/experiment.ts @@ -46,6 +46,9 @@ const defaultOpts: ExperimentOpts = { export async function experiment(name: string, opts?: ExperimentOpts): Promise { let exp: Experiment; + if (name === "cache") { + throw Error("Invalid experiment name."); + } if (IS_NODE) { const { DiskExperiment } = require("./disk_experiment"); exp = new DiskExperiment(name, opts); diff --git a/src/mnist.ts b/src/mnist.ts index 4ec71e02..981ecbf3 100644 --- a/src/mnist.ts +++ b/src/mnist.ts @@ -13,18 +13,19 @@ limitations under the License. */ import { tensor, Tensor } from "./api"; -import { assert, fetchArrayBuffer } from "./util"; +import * as cache from "./cache"; +import { assert } from "./util"; export function filenames(split: string): [string, string] { if (split === "train") { return [ - "deps/data/mnist/train-labels-idx1-ubyte.bin", - "deps/data/mnist/train-images-idx3-ubyte.bin", + "http://propelml.org/data/mnist/train-labels-idx1-ubyte.bin", + "http://propelml.org/data/mnist/train-images-idx3-ubyte.bin", ]; } else if (split === "test") { return [ - "deps/data/mnist/t10k-labels-idx1-ubyte.bin", - "deps/data/mnist/t10k-images-idx3-ubyte.bin", + "http://propelml.org/data/mnist/t10k-labels-idx1-ubyte.bin", + "http://propelml.org/data/mnist/t10k-images-idx3-ubyte.bin", ]; } else { throw new Error(`Bad split: ${split}`); @@ -48,7 +49,7 @@ export async function loadSplit(split: string): } async function loadFile2(href: string) { - const ab = await fetchArrayBuffer(href); + const ab = await cache.fetchWithCache(href); const i32 = new Int32Array(ab); const ui8 = new Uint8Array(ab); diff --git a/src/util_node.ts b/src/util_node.ts new file mode 100644 index 00000000..a7918484 --- /dev/null +++ b/src/util_node.ts @@ -0,0 +1,48 @@ +/*! + Copyright 2018 Propel http://propel.site/. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import * as fs from "fs"; +import * as path from "path"; +import { assert } from "./util"; + +export function isDir(p: string): boolean { + try { + return fs.statSync(p).isDirectory(); + } catch (e) { + if (e.code === "ENOENT") return false; + throw e; + } +} + +/** Returns "$HOME/.propel/" or PROPEL_DIR env var. */ +export function propelDir(): string { + if (process.env.PROPEL_DIR) { + return process.env.PROPEL_DIR; + } else { + const homeDir = process.platform === "win32" ? process.env.USERPROFILE + : process.env.HOME; + return path.join(homeDir, ".propel/"); + } +} + +/** Recursive mkdir. */ +export function mkdirp(dirname: string): void { + if (!isDir(dirname)) { + const parentDir = path.dirname(dirname); + assert(parentDir !== dirname && parentDir.length > 1); + mkdirp(parentDir); + fs.mkdirSync(dirname); + } +} diff --git a/tools/test_isomorphic.ts b/tools/test_isomorphic.ts index d1cd131e..43da8ed9 100644 --- a/tools/test_isomorphic.ts +++ b/tools/test_isomorphic.ts @@ -1,5 +1,6 @@ import "../src/api_test"; import "../src/backend_test"; +import "../src/cache_test"; import "../src/conv_test"; import "../src/dataset_test"; import "../src/example_test";