Skip to content

Commit

Permalink
Wrapper rules
Browse files Browse the repository at this point in the history
Signed-off-by: Aditya Kothari <akothari@lyft.com>
  • Loading branch information
kothariaditya committed Jul 23, 2019
1 parent 09df517 commit 9af9c01
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
4 changes: 3 additions & 1 deletion tests/harness/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ py_binary(
"//tests/harness:python-harness-proto",
"//tests/harness/cases:python",
requirement("validate-email"),
requirement("ipaddress")],
requirement("ipaddress"),
requirement("Jinja2"),
requirement("MarkupSafe"),],
main = "harness.py",
visibility = ["//tests/harness:__subpackages__"],
python_version = "PY2",
Expand Down
53 changes: 47 additions & 6 deletions validate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _has_field(message_pb, property_name):
return property_name in all_fields

def const_template(option_value, f, name):
const_tmpl = """{% if str(o.string) != "" and o.string['const'] %}
const_tmpl = """{% if str(o.string) != "" and o.string.HasField('const') %}
if p.{{ name }} != \"{{ o.string['const'] }}\":
raise ValidationFailed(\"{{ name }} not equal to {{ o.string['const'] }}\")
{% elif str(o.bool) != "" and o.bool['const'] != '' %}
Expand Down Expand Up @@ -483,6 +483,41 @@ def timestamp_template(option_value, f, name):
return Template(timestamp_tmpl).render(o=option_value,f=f,name=name,required_template=required_template, _has_field=_has_field,dur_lit=dur_lit,dur_arr=dur_arr)


def wrapper_template(option_value, field, name):
wrapper_tmpl = """
if p.HasField(\"{{ name }}\"):
{%- if str(option_value.float) != "" -%}
{{- num_template(option_value, field, name + ".value", option_value.float)|indent(8,True) -}}
{%- endif -%}
{%- if str(option_value.double) != "" -%}
{{- num_template(option_value, field, name + ".value", option_value.double)|indent(8,True) -}}
{%- endif -%}
{%- if str(option_value.int32) != "" -%}
{{- num_template(option_value, field, name + ".value", option_value.int32)|indent(8,True) -}}
{%- endif -%}
{%- if str(option_value.int64) != "" -%}
{{- num_template(option_value, field, name + ".value", option_value.int64)|indent(8,True) -}}
{%- endif -%}
{%- if str(option_value.uint32) != "" -%}
{{- num_template(option_value, field, name + ".value", option_value.uint32)|indent(8,True) -}}
{%- endif -%}
{%- if str(option_value.uint64) != "" -%}
{{- num_template(option_value, field, name + ".value", option_value.uint64)|indent(8,True) -}}
{%- endif -%}
{%- if str(option_value.bool) != "" -%}
{{- bool_template(option_value, field, name + ".value")|indent(8,True) -}}
{%- endif -%}
{%- if str(option_value.string) != "" -%}
{{- string_template(option_value, field, name + ".value")|indent(8,True) -}}
{%- endif -%}
pass
{% if str(option_value.message) != "" and option_value.message['required'] %}
else:
raise ValidationFailed(\"{{ name }} is required.\")
{% endif %}
"""
return Template(wrapper_tmpl).render(option_value = option_value, field = field, name = name, str = str, num_template = num_template, bool_template = bool_template, string_template = string_template)

def rule_type(field, name = ""):
if has_validate(field) and field.message_type is None and not field.containing_oneof:
for option_descriptor, option_value in field.GetOptions().ListFields():
Expand Down Expand Up @@ -526,6 +561,12 @@ def rule_type(field, name = ""):
return duration_template(option_value, field, ".".join([x for x in [name, field.name] if x]))
elif str(option_value.timestamp) is not "":
return timestamp_template(option_value, field, ".".join([x for x in [name, field.name] if x]))
elif str(option_value.float) or str(option_value.int32) or str(option_value.int64) or \
str(option_value.double) or str(option_value.uint32) or str(option_value.uint64) or \
str(option_value.bool) or str(option_value.string):
return wrapper_template(option_value, field, ".".join([x for x in [name, field.name] if x]))
elif str(option_value.bytes):
return "raise UnimplementedException()"
elif str(option_value.message) is not "":
return message_template(option_value, field, ".".join([x for x in [name, field.name] if x]))
else:
Expand All @@ -538,15 +579,15 @@ def rule_type(field, name = ""):

def file_template(proto_message):
file_tmp = """def validate(p):
{% set accessor = p.DESCRIPTOR -%}
{%- set accessor = p.DESCRIPTOR -%}
{% for option_descriptor, option_value in accessor.GetOptions().ListFields() %}
{% if option_descriptor.full_name == "validate.disabled" and option_value %}
return None
{% endif %}
{% endfor %}
{% for field in accessor.fields -%}
{{ rule_type(field) }}
{%- endfor %}
{% for field in accessor.fields %}
{{- rule_type(field) -}}
{% endfor %}
return None"""
return Template(file_tmp).render(rule_type=rule_type, p=proto_message, dir=dir)

Expand All @@ -559,4 +600,4 @@ class ValidationFailed(Exception):

def generate_validate(proto_message):
func = file_template(proto_message)
exec(func); return validate
exec(func); return validate

0 comments on commit 9af9c01

Please sign in to comment.