Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR: Improve calls to libpng-config on Ubuntu/Debian #2398

Merged
merged 2 commits into from
Jul 6, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import io
import re
import sys
import csv
from setuptools import setup, find_packages
from pkg_resources import parse_version, get_distribution, DistributionNotFound
import subprocess
Expand Down Expand Up @@ -135,6 +136,26 @@ def find_library(name, vision_include):
return library_found, conda_installed, include_folder, lib_folder


def get_linux_distribution():
release_data = {}
with open("/etc/os-release") as f:
reader = csv.reader(f, delimiter="=")
for row in reader:
if row:
release_data[row[0]] = row[1]
if release_data["ID"] in ["debian", "raspbian"]:
with open("/etc/debian_version") as f:
debian_version = f.readline().strip()
major_version = debian_version.split(".")[0]
version_split = release_data["VERSION"].split(" ", maxsplit=1)
if version_split[0] == major_version:
# Just major version shown, replace it with the full version
release_data["VERSION"] = " ".join(
[debian_version] + version_split[1:])
print("{} {}".format(release_data["NAME"], release_data["VERSION"]))
return release_data


def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')
Expand Down Expand Up @@ -246,6 +267,14 @@ def get_extensions():
image_library = []
image_link_flags = []

# Detect if build is running under conda/conda-build
conda = distutils.spawn.find_executable('conda')
is_conda = conda is not None

build_prefix = os.environ.get('BUILD_PREFIX', None)
is_conda_build = build_prefix is not None
running_under_conda = is_conda or is_conda_build

# Locating libPNG
libpng = distutils.spawn.find_executable('libpng-config')
pngfix = distutils.spawn.find_executable('pngfix')
Expand All @@ -262,14 +291,26 @@ def get_extensions():
png_version = parse_version(png_version)
if png_version >= parse_version("1.6.0"):
print('Building torchvision with PNG image support')
png_lib = subprocess.run([libpng, '--libdir'],
stdout=subprocess.PIPE)
linux = sys.platform == 'linux'
not_debian = False
libpng_on_conda = False
if linux:
bin_folder = os.path.dirname(sys.executable)
png_bin_folder = os.path.dirname(libpng)
libpng_on_conda = (
running_under_conda and bin_folder == png_bin_folder)
release_info = get_linux_distribution()
not_debian = release_info["NAME"] not in {'Ubuntu', 'Debian'}
if not linux or libpng_on_conda or not_debian:
png_lib = subprocess.run([libpng, '--libdir'],
stdout=subprocess.PIPE)
png_lib = png_lib.stdout.strip().decode('utf-8')
image_library += [png_lib]
png_include = subprocess.run([libpng, '--I_opts'],
stdout=subprocess.PIPE)
png_include = png_include.stdout.strip().decode('utf-8')
_, png_include = png_include.split('-I')
print('libpng include path: {0}'.format(png_include))
image_library += [png_lib.stdout.strip().decode('utf-8')]
image_include += [png_include]
image_link_flags.append('png')
else:
Expand Down