Skip to content

Commit

Permalink
First round of comments
Browse files Browse the repository at this point in the history
  • Loading branch information
samanvp committed Aug 27, 2018
1 parent 341bf03 commit 67a6340
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 67 deletions.
138 changes: 72 additions & 66 deletions gcp_variant_transforms/beam_io/vcf_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import vcf

from apache_beam.coders import coders
from apache_beam.io.filesystems import FileSystems
from apache_beam.io import filesystems
from apache_beam.io import textio


Expand All @@ -46,7 +46,7 @@
DEFAULT_PHASESET_VALUE = '*' # Default phaseset value if call is phased, but
# no 'PS' is present.
MISSING_GENOTYPE_VALUE = -1 # Genotype to use when '.' is used in GT field.

FIRST_HEADER_LINE = '##fileformat=VCFv4.2'

class Variant(object):
"""A class to store info about a genomic variant.
Expand Down Expand Up @@ -494,14 +494,14 @@ def __init__(self,
# This member will be properly initialized in _init_with_header().
self._vcf_reader = None
# These members will be properly initialized in _extract_header_fields().
self._is_info_repeated = {}
self._is_format_repeated = {}
self._header_infos = {}
self._header_formats = {}
# This is a temporary solution until from_string is fixed.
self._temp_local_file = self._write_to_local_file(file_name)

def _write_to_local_file(self, remote_file_name):
(temp_file, temp_file_name) = tempfile.mkstemp(text=True)
with FileSystems.open(remote_file_name) as f:
with filesystems.FileSystems.open(remote_file_name) as f:
while True:
line = f.readline()
if line:
Expand All @@ -511,73 +511,78 @@ def _write_to_local_file(self, remote_file_name):
break
return temp_file_name

def _store_to_temp_local_file(self, header_lines):
(temp_file, temp_file_name) = tempfile.mkstemp(text=True)
for line in header_lines:
os.write(temp_file, line)
os.close(temp_file)
return temp_file_name

def _init_with_header(self, header_lines):
  """Creates the Nucleus VCF reader and caches the header field definitions.

  Args:
    header_lines: Header lines of the VCF file. `FIRST_HEADER_LINE` is
      prepended when the first line does not start with '##fileformat'.

  Raises:
    ValueError: If constructing the reader raises ValueError; the message is
      re-raised with the file name added.
  """
  # The first header line must be similar to '##fileformat=VCFv4.2'.
  if header_lines and not header_lines[0].startswith('##fileformat'):
    header_lines = [FIRST_HEADER_LINE] + header_lines
  # NOTE(review): `header_lines` is not handed to the reader below, which reads
  # the local temp file instead -- confirm this is intended while the
  # from_string workaround is in place.
  try:
    # This is a temporary solution until from_string is fixed.
    self._vcf_reader = nucleus.io.vcf.VcfReader(self._temp_local_file,
                                                use_index=False)
  except ValueError as e:
    raise ValueError(
        'Invalid VCF header in %s: %s' % (self._file_name, str(e)))
  self._extract_header_fields()

def _extract_header_fields(self):
header = self._vcf_reader.header
for info in header.infos:
if info.number in ('0', '1'):
self._is_info_repeated[info.id] = False
else:
self._is_info_repeated[info.id] = True
self._header_infos[info.id] = info

for format_info in header.formats:
if format_info.number in ('0', '1'):
self._is_format_repeated[format_info.id] = False
else:
self._is_format_repeated[format_info.id] = True
self._header_formats[format_info.id] = format_info

def _is_info_repeated(self, info_id):
info = self._header_infos.get(info_id, None)
if not info or not info.number:
return False
else:
return self._is_repeated(info.number)

def _is_format_repeated(self, format_id):
format_info = self._header_formats.get(format_id, None)
if not format_info or not format_info.number:
return False
else:
return self._is_repeated(format_info.number)

def _is_repeated(self, number):
if number in ('0', '1'):
return False
else:
return True

def _get_variant(self, data_line):
try:
# This is a temporary solution until from_string will be fixed.
record = next(self._vcf_reader)
return self._convert_to_variant(record)
variant_proto = next(self._vcf_reader)
return self._convert_to_variant(variant_proto)
except (LookupError, ValueError) as e:
logging.warning('VCF record read failed in %s for line %s: %s',
logging.warning('VCF variant_proto read failed in %s for line %s: %s',
self._file_name, data_line, str(e))
return MalformedVcfRecord(self._file_name, data_line, str(e))

def _convert_to_variant(self, variant_proto):
  # type: (nucleus.protos.variants_pb2.Variant) -> Variant
  """Builds a `Variant` from a Nucleus variant proto.

  Args:
    variant_proto: The proto returned by the VCF reader.

  Returns:
    A `Variant` whose fields are copied from `variant_proto`.
    `reference_bases` is None when the proto holds `MISSING_FIELD_VALUE`.
  """
  reference_bases = variant_proto.reference_bases
  if reference_bases == MISSING_FIELD_VALUE:
    reference_bases = None
  return Variant(
      reference_name=variant_proto.reference_name,
      start=variant_proto.start,
      end=variant_proto.end,
      reference_bases=reference_bases,
      alternate_bases=variant_proto.alternate_bases,
      # List comprehensions give the same lists as `map` on Python 2 and stay
      # real lists (not lazy iterators) on Python 3.
      names=[str(name) for name in variant_proto.names],
      quality=variant_proto.quality,
      filters=[str(filter_name) for filter_name in variant_proto.filter],
      info=self._get_variant_info(variant_proto),
      calls=self._get_variant_calls(variant_proto))

def _get_variant_info(self, variant_proto):
info = {}
for k in record.info:
data = self._convert_list_value(record.info[k],
self._is_info_repeated.get(k, False))
for k in variant_proto.info:
data = self._convert_list_value(variant_proto.info[k],
self._is_info_repeated(k))
# Prevents including missing flags as 'false' flag.
if isinstance(data, bool) and not data:
continue
Expand Down Expand Up @@ -610,38 +615,39 @@ def _convert_list_value(self, list_values, is_repeated):

if is_repeated:
return output_list
if not output_list:
return None
if len(output_list) == 1:
return output_list[0]
raise ValueError('a not repeated field has more than 1 value')
else:
if len(output_list) > 1:
raise ValueError('a not repeated field has more than 1 value')
if len(output_list) == 1:
return output_list[0]
else:
return None

def _get_variant_calls(self, variant_proto):
  """Builds one `VariantCall` per call in `variant_proto`.

  Args:
    variant_proto: The Nucleus variant proto whose `calls` are converted.

  Returns:
    A list of `VariantCall` objects, in the order of `variant_proto.calls`.
  """
  calls = []
  for call_proto in variant_proto.calls:
    call = VariantCall()
    call.name = call_proto.call_set_name
    if not call_proto.genotype:
      call.genotype.append(MISSING_GENOTYPE_VALUE)
    else:
      call.genotype = list(call_proto.genotype)

    # Check with .get() before indexing so `info` is only indexed when the
    # key holds a truthy value.
    phaseset_from_format = (
        self._convert_list_value(call_proto.info[PHASESET_FORMAT_KEY], False)
        if call_proto.info.get(PHASESET_FORMAT_KEY, None)
        else None)
    # Note: Call is considered phased if it contains the 'PS' key regardless
    # of whether it uses '|'.
    if phaseset_from_format or call_proto.is_phased:
      call.phaseset = (phaseset_from_format if phaseset_from_format
                       else DEFAULT_PHASESET_VALUE)
    for k in call_proto.info:
      # Genotype and phaseset (if present) are already included.
      if k in (GENOTYPE_FORMAT_KEY, PHASESET_FORMAT_KEY):
        continue
      call.info[k] = self._convert_list_value(call_proto.info[k],
                                              self._is_format_repeated(k))
    calls.append(call)
  return calls
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
'google-api-python-client>=1.6',
'intervaltree>=2.1.0,<2.2.0',
'pyvcf<0.7.0',
'nucleus',
'mmh3<2.6',
# Need to explicitly install v<=1.2.0. apache-beam requires
# google-cloud-pubsub 0.26.0, which relies on google-cloud-core<0.26dev,
Expand Down

0 comments on commit 67a6340

Please sign in to comment.