Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix 3rd-party training service bug #3726

Merged
merged 6 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nni/tools/nnictl/ts_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def register(args):

try:
service_config = {
'node_module_path': info.node_module_path,
'node_class_name': info.node_class_name,
'nodeModulePath': str(info.node_module_path),
'nodeClassName': info.node_class_name,
}
json.dumps(service_config)
except Exception:
Expand Down
168 changes: 72 additions & 96 deletions ts/nni_manager/common/experimentStartupInfo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,43 @@
import * as assert from 'assert';
import * as os from 'os';
import * as path from 'path';
import * as component from '../common/component';

@component.Singleton
class ExperimentStartupInfo {
private readonly API_ROOT_URL: string = '/api/v1/nni';

private experimentId: string = '';
private newExperiment: boolean = true;
private basePort: number = -1;
private initialized: boolean = false;
private logDir: string = '';
private logLevel: string = '';
private readonly: boolean = false;
private dispatcherPipe: string | null = null;
private platform: string = '';
private urlprefix: string = '';

public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string, urlprefix?: string): void {
assert(!this.initialized);
assert(experimentId.trim().length > 0);

const API_ROOT_URL: string = '/api/v1/nni';

let singleton: ExperimentStartupInfo | null = null;

export class ExperimentStartupInfo {

public experimentId: string = '';
public newExperiment: boolean = true;
public basePort: number = -1;
public initialized: boolean = false;
public logDir: string = '';
public logLevel: string = '';
public readonly: boolean = false;
public dispatcherPipe: string | null = null;
public platform: string = '';
public urlprefix: string = '';

constructor(
newExperiment: boolean,
experimentId: string,
basePort: number,
platform: string,
logDir?: string,
logLevel?: string,
readonly?: boolean,
dispatcherPipe?: string,
urlprefix?: string) {
this.newExperiment = newExperiment;
this.experimentId = experimentId;
this.basePort = basePort;
this.initialized = true;
this.platform = platform;

if (logDir !== undefined && logDir.length > 0) {
this.logDir = path.join(path.normalize(logDir), this.getExperimentId());
this.logDir = path.join(path.normalize(logDir), experimentId);
} else {
this.logDir = path.join(os.homedir(), 'nni-experiments', this.getExperimentId());
this.logDir = path.join(os.homedir(), 'nni-experiments', experimentId);
}

if (logLevel !== undefined && logLevel.length > 1) {
Expand All @@ -55,98 +62,67 @@ class ExperimentStartupInfo {
}
}

public getExperimentId(): string {
assert(this.initialized);

return this.experimentId;
}

public getBasePort(): number {
assert(this.initialized);

return this.basePort;
}

public isNewExperiment(): boolean {
assert(this.initialized);

return this.newExperiment;
}

public getPlatform(): string {
assert(this.initialized);

return this.platform;
public get apiRootUrl(): string {
return this.urlprefix === '' ? API_ROOT_URL : `/${this.urlprefix}${API_ROOT_URL}`;
}

public getLogDir(): string {
assert(this.initialized);

return this.logDir;
}

public getLogLevel(): string {
assert(this.initialized);

return this.logLevel;
}

public isReadonly(): boolean {
assert(this.initialized);

return this.readonly;
}

public getDispatcherPipe(): string | null {
assert(this.initialized);
return this.dispatcherPipe;
}

public getAPIRootUrl(): string {
assert(this.initialized);
return this.urlprefix==''?this.API_ROOT_URL:`/${this.urlprefix}${this.API_ROOT_URL}`;
public static getInstance(): ExperimentStartupInfo {
assert(singleton !== null);
return singleton!;
}
}

function getExperimentId(): string {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getExperimentId();
export function getExperimentStartupInfo(): ExperimentStartupInfo {
return ExperimentStartupInfo.getInstance();
}

function getBasePort(): number {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getBasePort();
export function setExperimentStartupInfo(
newExperiment: boolean,
experimentId: string,
basePort: number,
platform: string,
logDir?: string,
logLevel?: string,
readonly?: boolean,
dispatcherPipe?: string,
urlprefix?: string): void {
singleton = new ExperimentStartupInfo(
newExperiment,
experimentId,
basePort,
platform,
logDir,
logLevel,
readonly,
dispatcherPipe,
urlprefix
);
}

function isNewExperiment(): boolean {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isNewExperiment();
export function getExperimentId(): string {
return getExperimentStartupInfo().experimentId;
}

function getPlatform(): string {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getPlatform();
export function getBasePort(): number {
return getExperimentStartupInfo().basePort;
}

function getExperimentStartupInfo(): ExperimentStartupInfo {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo);
export function isNewExperiment(): boolean {
return getExperimentStartupInfo().newExperiment;
}

function setExperimentStartupInfo(
newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string, urlprefix?: string): void {
component.get<ExperimentStartupInfo>(ExperimentStartupInfo)
.setStartupInfo(newExperiment, experimentId, basePort, platform, logDir, logLevel, readonly, dispatcherPipe, urlprefix);
export function getPlatform(): string {
return getExperimentStartupInfo().platform;
}

function isReadonly(): boolean {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isReadonly();
export function isReadonly(): boolean {
return getExperimentStartupInfo().readonly;
}

function getDispatcherPipe(): string | null {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getDispatcherPipe();
export function getDispatcherPipe(): string | null {
return getExperimentStartupInfo().dispatcherPipe;
}

function getAPIRootUrl(): string {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getAPIRootUrl();
export function getAPIRootUrl(): string {
return getExperimentStartupInfo().apiRootUrl;
}

export {
ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, getPlatform, getExperimentStartupInfo,
setExperimentStartupInfo, isReadonly, getDispatcherPipe, getAPIRootUrl
};
25 changes: 15 additions & 10 deletions ts/nni_manager/common/pythonScript.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,32 @@
import { spawn } from 'child_process';
import { Logger, getLogger } from './log';

const python = process.platform === 'win32' ? 'python.exe' : 'python3';
const logger: Logger = getLogger('pythonScript');

export async function runPythonScript(script: string, logger?: Logger): Promise<string> {
const python: string = process.platform === 'win32' ? 'python.exe' : 'python3';

export async function runPythonScript(script: string, logTag?: string): Promise<string> {
const proc = spawn(python, [ '-c', script ]);

let stdout: string = '';
let stderr: string = '';
proc.stdout.on('data', (data: string) => { stdout += data; });
proc.stderr.on('data', (data: string) => { stderr += data; });

const procPromise = new Promise<void>((resolve, reject) => {
proc.on('error', (err: Error) => { reject(err); });
proc.on('exit', () => { resolve(); });
});
await procPromise;

const stdout = proc.stdout.read().toString();
const stderr = proc.stderr.read().toString();

if (stderr) {
if (logger === undefined) {
logger = getLogger('pythonScript');
if (logTag) {
logger.warning(`Python script [${logTag}] has stderr:`, stderr);
} else {
logger.warning('Python script has stderr.');
logger.warning(' script:', script);
logger.warning(' stderr:', stderr);
}
logger.warning('python script has stderr.');
logger.warning('script:', script);
logger.warning('stderr:', stderr);
}

return stdout;
Expand Down
9 changes: 4 additions & 5 deletions ts/nni_manager/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ import * as util from 'util';
import * as glob from 'glob';

import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
import { getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
import { ExperimentConfig, Manager } from './manager';
import { ExperimentManager } from './experimentManager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';

function getExperimentRootDir(): string {
return getExperimentStartupInfo().getLogDir();
return getExperimentStartupInfo().logDir;
}

/**
 * Builds the path of the `log` subdirectory located directly under the
 * experiment root directory.
 *
 * @returns The experiment root directory joined with `log`.
 */
function getLogDir(): string {
    const experimentRoot: string = getExperimentRootDir();
    return path.join(experimentRoot, 'log');
}

function getLogLevel(): string {
return getExperimentStartupInfo().getLogLevel();
return getExperimentStartupInfo().logLevel;
}

function getDefaultDatabaseDir(): string {
Expand Down Expand Up @@ -184,7 +184,6 @@ function generateParamFileName(hyperParameters: HyperParameters): string {
* Must be paired with `cleanupUnitTest()`.
*/
function prepareUnitTest(): void {
Container.snapshot(ExperimentStartupInfo);
Container.snapshot(Database);
Container.snapshot(DataStore);
Container.snapshot(TrainingService);
Expand Down Expand Up @@ -213,7 +212,7 @@ function cleanupUnitTest(): void {
Container.restore(TrainingService);
Container.restore(DataStore);
Container.restore(Database);
Container.restore(ExperimentStartupInfo);
setExperimentStartupInfo(true, 'unittest', 8080, 'unittest', undefined, logLevel);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use 'unittest' as experiment id here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original logic was to snapshot ExperimentStartupInfo in prepareUnitTest and restore it here using the IoC container's utility.
To share ExperimentStartupInfo with 3rd-party training services, it is no longer managed by the IoC container and has to be restored manually.
This line is copied from prepareUnitTest so that the same state is restored.

Container.restore(ExperimentManager);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import * as path from 'path';
import * as component from '../../../common/component';
import { getLogger, Logger } from '../../../common/log';
import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig';
import { ExperimentStartupInfo } from '../../../common/experimentStartupInfo';
import { validateCodeDir } from '../../common/util';
import { AMLClient } from '../aml/amlClient';
import { AMLEnvironmentInformation } from '../aml/amlConfig';
Expand All @@ -29,10 +30,10 @@ export class AMLEnvironmentService extends EnvironmentService {
private experimentRootDir: string;
private config: FlattenAmlConfig;

constructor(experimentRootDir: string, experimentId: string, config: ExperimentConfig) {
constructor(config: ExperimentConfig, info: ExperimentStartupInfo) {
super();
this.experimentId = experimentId;
this.experimentRootDir = experimentRootDir;
this.experimentId = info.experimentId;
this.experimentRootDir = info.logDir;
this.config = flattenConfig(config, 'aml');
validateCodeDir(this.config.trialCodeDirectory);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,22 @@ import { LocalEnvironmentService } from './localEnvironmentService';
import { RemoteEnvironmentService } from './remoteEnvironmentService';
import { EnvironmentService } from '../environment';
import { ExperimentConfig } from '../../../common/experimentConfig';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { ExperimentStartupInfo } from '../../../common/experimentStartupInfo';
import { getCustomEnvironmentServiceConfig } from '../../../common/nniConfig';
import { getExperimentRootDir, importModule } from '../../../common/utils';

import { importModule } from '../../../common/utils';

export async function createEnvironmentService(name: string, config: ExperimentConfig): Promise<EnvironmentService> {
const expId = getExperimentId();
const rootDir = getExperimentRootDir();
const info = ExperimentStartupInfo.getInstance();

switch(name) {
case 'local':
return new LocalEnvironmentService(rootDir, expId, config);
return new LocalEnvironmentService(config, info);
case 'remote':
return new RemoteEnvironmentService(rootDir, expId, config);
return new RemoteEnvironmentService(config, info);
case 'aml':
return new AMLEnvironmentService(rootDir, expId, config);
return new AMLEnvironmentService(config, info);
case 'openpai':
return new OpenPaiEnvironmentService(rootDir, expId, config);
return new OpenPaiEnvironmentService(config, info);
}

const esConfig = await getCustomEnvironmentServiceConfig(name);
Expand All @@ -30,5 +28,5 @@ export async function createEnvironmentService(name: string, config: ExperimentC
}
const esModule = importModule(esConfig.nodeModulePath);
const esClass = esModule[esConfig.nodeClassName] as any;
return new esClass(rootDir, expId, config);
return new esClass(config, info);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import * as tkill from 'tree-kill';
import * as component from '../../../common/component';
import { getLogger, Logger } from '../../../common/log';
import { ExperimentConfig } from '../../../common/experimentConfig';
import { ExperimentStartupInfo } from '../../../common/experimentStartupInfo';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { isAlive, getNewLine } from '../../../common/utils';
import { execMkdir, runScript, getScriptName, execCopydir } from '../../common/util';
Expand All @@ -21,10 +22,10 @@ export class LocalEnvironmentService extends EnvironmentService {
private experimentRootDir: string;
private experimentId: string;

constructor(experimentRootDir: string, experimentId: string, _config: ExperimentConfig) {
constructor(_config: ExperimentConfig, info: ExperimentStartupInfo) {
super();
this.experimentId = experimentId;
this.experimentRootDir = experimentRootDir;
this.experimentId = info.experimentId;
this.experimentRootDir = info.logDir;
}

public get environmentMaintenceLoopInterval(): number {
Expand Down
Loading