Skip to content

Commit

Permalink
fix windows encoding bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
cjld committed Mar 9, 2022
1 parent 53b377e commit 495d78a
Show file tree
Hide file tree
Showing 16 changed files with 113 additions and 59 deletions.
6 changes: 2 additions & 4 deletions python/jittor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.1.44'
__version__ = '1.3.1.45'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
Expand Down Expand Up @@ -410,9 +410,7 @@ def flatten(input, start_dim=0, end_dim=-1):
return input.reshape(out_shape)
Var.flatten = flatten

def start_grad(x):
return x._update(x)
Var.detach_inplace = Var.start_grad = start_grad
Var.detach_inplace = Var.start_grad

def detach(x):
return x.detach()
Expand Down
5 changes: 5 additions & 0 deletions python/jittor/compile_extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,11 @@ def setup_cub():

def setup_cuda_extern():
if not has_cuda: return
check_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
if "cuda" in check_ld_path.lower() and "lib" in check_ld_path.lower():
LOG.w(f"CUDA related path found in LD_LIBRARY_PATH({check_ld_path}), "
"This path may cause jittor found the wrong libs, "
"please unset LD_LIBRARY_PATH. ")
LOG.vv("setup cuda extern...")
cache_path_cuda = os.path.join(cache_path, "cuda")
cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc")
Expand Down
18 changes: 9 additions & 9 deletions python/jittor/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def gen_jit_tests():
}} // jittor
"""
LOG.vvvv(jit_src)
with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w') as f:
with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w', encoding='utf8') as f:
f.write(jit_src)

def gen_jit_flags():
Expand Down Expand Up @@ -257,7 +257,7 @@ def gen_jit_flags():
}} // jittor
"""
LOG.vvvv(jit_src)
with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w') as f:
with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w', encoding='utf8') as f:
f.write(jit_src)

def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
Expand Down Expand Up @@ -639,9 +639,9 @@ def compile_custom_op(header, source, op_name, warp=True):
make_cache_dir(cops_dir)
hname = os.path.join(cops_dir, op_name+"_op.h")
ccname = os.path.join(cops_dir, op_name+"_op.cc")
with open(hname, 'w') as f:
with open(hname, 'w', encoding='utf8') as f:
f.write(header)
with open(ccname, 'w') as f:
with open(ccname, 'w', encoding='utf8') as f:
f.write(source)
m = compile_custom_ops([hname, ccname])
return getattr(m, op_name)
Expand Down Expand Up @@ -679,7 +679,7 @@ def compile_custom_ops(
dirname = os.path.dirname(name)
if dirname.endswith("inc"):
includes.append(dirname)
with open(name, "r") as f:
with open(name, "r", encoding='utf8') as f:
if "@pyjt" in f.read():
pyjt_includes.append(name)
bname = os.path.basename(name)
Expand Down Expand Up @@ -736,7 +736,7 @@ def insert_anchor(gen_src, anchor_str, insert_str):
"init_module(PyModuleDef* mdef, PyObject* m) {",
f"jittor::pyjt_def_{bname}(m);")

with open(gen_head_fname, "w") as f:
with open(gen_head_fname, "w", encoding='utf8') as f:
f.write(gen_src)

LOG.vvv(f"Build custum ops lib:{gen_lib}")
Expand Down Expand Up @@ -781,7 +781,7 @@ def compile_extern():
files = os.listdir(jittor_path_llvm)
# test_pass.cc is used for test link problem of llvm pass plugin
test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
with open(test_pass_path, 'w') as f:
with open(test_pass_path, 'w', encoding='utf8') as f:
f.write("int main() {return 0;}")

# -fno-rtti fix link error
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def fix_cl_flags(cmd):
cc_flags = cc_flags.replace("-lstdc++", "")
cc_flags = cc_flags.replace("-ldl", "")
cc_flags += f" -L\"{py3_link_path}\" -lpython3{sys.version_info.minor} "
cc_flags += " -EHa -MD "
cc_flags += " -EHa -MD -utf-8 "
import jittor_utils
if jittor_utils.msvc_path:
mp = jittor_utils.msvc_path
Expand Down Expand Up @@ -1217,7 +1217,7 @@ def func(x):
op_headers = glob.glob(jittor_path+"/src/ops/**/*op.h", recursive=True)
jit_src = gen_jit_op_maker(op_headers)
LOG.vvvv(jit_src)
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w', encoding='utf8') as f:
f.write(jit_src)
cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" -L\"{jit_utils.cache_path}\" '
# gen pyjt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@
" train_loss_list.append(train_loss)\n",
" # 在验证集上进行验证,模型参数不做更新。\n",
" val_loss = val(model, x_val_var, y_val_var, loss_function)\n",
" val_loss_list.append(val_loss)\n",
" val_loss_list.append(val_loss.item())\n",
" \n",
"# 打印训练结束后的模型参数\n",
"print(\"After training: \\n\", model.state_dict())"
Expand Down Expand Up @@ -598,4 +598,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@
" \n",
" plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # 利用 matplotlib 根据第一个 input 绘制手写数字的图像\n",
" plt.show() # 展示图像\n",
" print(\"target:\", targets[num].data[0]) # 打印第一个 input 数据的真实标签值,即手写数字图像所表达的真实数字\n",
" print(\"target:\", targets[num].numpy()[0]) # 打印第一个 input 数据的真实标签值,即手写数字图像所表达的真实数字\n",
" break"
]
},
Expand Down Expand Up @@ -910,12 +910,12 @@
" outputs = model(inputs) # 通过模型预测手写数字。outputs 中每个数据输出有 10 个分量,对应十个数字的相似度\n",
" loss = loss_function(outputs, targets) # 计算损失函数\n",
" optimizer.step(loss) # 根据损失函数,对模型参数进行优化、更新\n",
" train_losses.append(loss) # 记录该批次的 Loss\n",
" train_losses.append(loss.item()) # 记录该批次的 Loss\n",
" \n",
" if batch_idx % 10 == 0: # 每十个批次,打印一次训练集上的 Loss \n",
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
" epoch, batch_idx, len(train_loader),\n",
" 100. * batch_idx / len(train_loader), loss.data[0]))\n",
" 100. * batch_idx / len(train_loader), loss.item()))\n",
" return train_losses # 返回本纪元的 Loss\n",
"\n",
"\n",
Expand All @@ -926,8 +926,8 @@
" total_num = 0 # 本纪元数据总数\n",
" for batch_idx, (inputs, targets) in enumerate(val_loader): # 通过测试集加载器,按批次迭代数据\n",
" outputs = model(inputs) # 通过模型预测手写数字。outputs 中每个数据输出有 10 个分量,对应十个数字的相似度\n",
" pred = np.argmax(outputs.data, axis=1) # 根据 10 个分量,选择最大相似度的为预测的数字值\n",
" correct = np.sum(targets.data==pred) # 计算本批次中,正确预测的次数,即数据标签等于预测值的数目\n",
" pred = np.argmax(outputs.numpy(), axis=1) # 根据 10 个分量,选择最大相似度的为预测的数字值\n",
" correct = np.sum(targets.numpy()==pred) # 计算本批次中,正确预测的次数,即数据标签等于预测值的数目\n",
" batch_size = inputs.shape[0] # 计算本批次中,数据的总数目\n",
" acc = correct / batch_size # 计算本批次的正确率\n",
" \n",
Expand Down Expand Up @@ -1075,10 +1075,10 @@
" plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # 绘制该数据的手写数字图像\n",
" plt.show() \n",
" \n",
" print(\"target:\", targets[num].data[0]) # 打印该数据的真实标签值\n",
" print(\"target:\", targets[num].numpy()[0]) # 打印该数据的真实标签值\n",
" \n",
" outputs = model(inputs) # 模型根据输入数据进行预测\n",
" pred = np.argmax(outputs.data, axis=1) # 根据最大相似度得到预测值\n",
" pred = np.argmax(outputs.numpy(), axis=1) # 根据最大相似度得到预测值\n",
" print(\"prediction:\", pred[num]) # 打印该数据的预测值\n",
" break"
]
Expand Down Expand Up @@ -1158,4 +1158,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
2 changes: 1 addition & 1 deletion python/jittor/notebook/md_to_ipynb.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@
ipynb_name = os.path.basename(mdname[:-2])+"ipynb"
ipynb_name = os.path.join(notebook_dir, ipynb_name)
print(mdname, len(src), len(blocks), len(cells), "--->", ipynb_name)
with open(ipynb_name, "w") as f:
with open(ipynb_name, "w", encoding='utf8') as f:
f.write(json.dumps(ipynb))
10 changes: 5 additions & 5 deletions python/jittor/pyjt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,13 +858,13 @@ def find_bc(i):
def compile_single(head_file_name, src_file_name, src=None):
basename = os.path.basename(head_file_name).split(".")[0]
if src==None:
with open(head_file_name, 'r') as f:
with open(head_file_name, 'r', encoding='utf8') as f:
src = f.read()
code = compile_src(src, head_file_name, basename)
if not code: return False
LOG.vvv("write to", src_file_name)
LOG.vvvv(code)
with open(src_file_name, 'w') as f:
with open(src_file_name, 'w', encoding='utf8') as f:
f.write(code)
return True

Expand All @@ -875,14 +875,14 @@ def compile(cache_path, jittor_path):
basenames = []
pyjt_names = []
for h in headers:
with open(h, 'r') as f:
with open(h, 'r', encoding='utf8') as f:
src = f.read()

bh = os.path.basename(h)
# jit_op_maker.h merge compile with var_holder.h
if bh == "var_holder.h": continue
if bh == "jit_op_maker.h":
with open(os.path.join(jittor_path, "src", "var_holder.h"), "r") as f:
with open(os.path.join(jittor_path, "src", "var_holder.h"), "r", encoding='utf8') as f:
src = f.read() + src
basename = bh.split(".")[0]
fname = "pyjt_"+basename+".cc"
Expand Down Expand Up @@ -913,7 +913,7 @@ def compile(cache_path, jittor_path):
fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
LOG.vvv(("write to", fname))
LOG.vvvv(code)
with open(fname, "w") as f:
with open(fname, "w", encoding='utf8') as f:
f.write(code)
pyjt_names.append(fname)
return pyjt_names
4 changes: 2 additions & 2 deletions python/jittor/script/make_doc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import os

def fix_config(in_name, out_name, src_path, out_path):
data = open(in_name, 'r').readlines()
data = open(in_name, 'r', encoding='utf8').readlines()
out = []
for d in data:
if d.startswith('INPUT ='):
d = f'INPUT ={src_path}\n'
elif d.startswith('OUTPUT_DIRECTORY ='):
d = f'OUTPUT_DIRECTORY ={out_path}\n'
out.append(d)
f = open(out_name, 'w')
f = open(out_name, 'w', encoding='utf8')
f.writelines(out)

jt_path = os.getcwd()
Expand Down
8 changes: 5 additions & 3 deletions python/jittor/src/jit_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
const char* msg = "";
LOGvv << "Opening jit lib:" << name;
#ifdef _WIN32
void* handle = (void*)LoadLibraryExA(name.c_str(), nullptr,
void* handle = (void*)LoadLibraryExA(Utf8ToGbk(name.c_str()).c_str(), nullptr,
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
LOAD_LIBRARY_SEARCH_USER_DIRS);
#elif defined(__linux__)
Expand Down Expand Up @@ -206,13 +206,15 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
#ifdef _WIN32
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".dll");
string jit_src_path2 = Utf8ToGbk(jit_src_path.c_str());
#else
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".so");
string& jit_src_path2 = jit_src_path;
#endif
string other_src;
LOGvvv << "Generate" << jit_src_path >> "\n" >> src;
if (rewrite_op || !file_exist(jit_src_path))
write(jit_src_path, src);
if (rewrite_op || !file_exist(jit_src_path2))
write(jit_src_path2, src);
string cmd;

auto symbol_name = get_symbol_name(jit_key);
Expand Down
6 changes: 0 additions & 6 deletions python/jittor/src/pyjt/py_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,7 @@ DEF_IS(string, PyObject*) to_py_object(const string& a) {

DEF_IS(string, string) from_py_object(PyObject* obj) {
Py_ssize_t size;
#ifdef _WIN32
PyObjHolder a(PyUnicode_AsEncodedString(obj, win_encode.c_str(), "strict"));
char* s;
auto ret = PyBytes_AsStringAndSize(a.obj, &s, &size);
#else
const char* s = PyUnicode_AsUTF8AndSize(obj, &size);
#endif
CHECK(s);
return string(s, size);
}
Expand Down
13 changes: 12 additions & 1 deletion python/jittor/src/utils/cache_compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,15 @@ static inline bool is_full_path(const string& name) {
#endif
}

bool cache_compile(string cmd, const string& cache_path, const string& jittor_path) {
bool cache_compile(string cmd, const string& cache_path_, const string& jittor_path_) {
#ifdef _WIN32
cmd = Utf8ToGbk(cmd.c_str());
string cache_path = Utf8ToGbk(cache_path_.c_str());
string jittor_path = Utf8ToGbk(jittor_path_.c_str());
#else
const string& cache_path = cache_path_;
const string& jittor_path = jittor_path_;
#endif
vector<string> input_names;
map<string,vector<string>> extra;
string output_name;
Expand All @@ -255,6 +263,9 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
continue;
processed.insert(input_names[i]);
auto src = read_all(input_names[i]);
#ifdef _WIN32
src = Utf8ToGbk(src.c_str());
#endif
auto back = input_names[i].back();
// *.lib
if (back == 'b') continue;
Expand Down
47 changes: 37 additions & 10 deletions python/jittor/src/utils/log.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,19 +316,13 @@ int register_sigaction() {
return 0;
}

#ifdef _WIN32
string win_encode;
#endif

static int log_init() {
register_sigaction();
std::atexit(log_exiting);
#ifdef _WIN32
if (getenv("JITTOR_ENCODE"))
win_encode = getenv("JITTOR_ENCODE");
else
win_encode = "gbk";
SetConsoleCP(CP_UTF8);
SetConsoleOutputCP(CP_UTF8);
#endif
register_sigaction();
std::atexit(log_exiting);
return 1;
}

Expand Down Expand Up @@ -456,6 +450,39 @@ If you still have problems, please contact us:
}

#ifdef _WIN32

string GbkToUtf8(const char *src_str)
{
int len = MultiByteToWideChar(CP_ACP, 0, src_str, -1, NULL, 0);
wchar_t* wstr = new wchar_t[len + 1];
memset(wstr, 0, len + 1);
MultiByteToWideChar(CP_ACP, 0, src_str, -1, wstr, len);
len = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, NULL, 0, NULL, NULL);
char* str = new char[len + 1];
memset(str, 0, len + 1);
WideCharToMultiByte(CP_UTF8, 0, wstr, -1, str, len, NULL, NULL);
string strTemp = str;
if (wstr) delete[] wstr;
if (str) delete[] str;
return strTemp;
}

string Utf8ToGbk(const char *src_str)
{
int len = MultiByteToWideChar(CP_UTF8, 0, src_str, -1, NULL, 0);
wchar_t* wszGBK = new wchar_t[len + 1];
memset(wszGBK, 0, len * 2 + 2);
MultiByteToWideChar(CP_UTF8, 0, src_str, -1, wszGBK, len);
len = WideCharToMultiByte(CP_ACP, 0, wszGBK, -1, NULL, 0, NULL, NULL);
char* szGBK = new char[len + 1];
memset(szGBK, 0, len + 1);
WideCharToMultiByte(CP_ACP, 0, wszGBK, -1, szGBK, len, NULL, NULL);
string strTemp(szGBK);
if (wszGBK) delete[] wszGBK;
if (szGBK) delete[] szGBK;
return strTemp;
}

int system_popen(const char *cmd, const char* cwd) {
HANDLE g_hChildStd_OUT_Rd = NULL;
HANDLE g_hChildStd_OUT_Wr = NULL;
Expand Down
8 changes: 4 additions & 4 deletions python/jittor/src/utils/log.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ namespace jittor {
// define in tracer.cc
void print_trace();
void breakpoint();
#ifdef _WIN32
string GbkToUtf8(const char *src_str);
string Utf8ToGbk(const char *src_str);
#endif

constexpr int32_t basename_index(const char * const path, const int32_t index = 0, const int32_t slash_index = -1) {
return path[index]
Expand Down Expand Up @@ -277,8 +281,4 @@ bool check_vlog(const char* fileline, int verbose);

void system_with_check(const char* cmd, const char* cwd=nullptr);

#ifdef _WIN32
extern string win_encode;
#endif

} // jittor
Loading

0 comments on commit 495d78a

Please sign in to comment.